{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "q4WF3l23pumU"
      },
      "source": [
        "##### Copyright 2018 The AdaNet Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "Kic2quJWppmx"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aL7SpaKdirqG"
      },
      "source": [
        "# Customizing AdaNet"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "i5s1gsS1bOuB"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/adanet/blob/master/adanet/examples/tutorials/customizing_adanet.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/adanet/blob/master/adanet/examples/tutorials/customizing_adanet.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ySKpPtPrmaCx"
      },
      "source": [
        "Often, as a researcher or machine learning practitioner, you will have\n",
        "some prior knowledge about a dataset. Ideally you should be able to encode that\n",
        "knowledge into your machine learning algorithm. With `adanet`, you can do so by\n",
        "defining the *neural architecture search space* that the AdaNet algorithm should\n",
        "explore.\n",
        "\n",
        "In this tutorial, we will explore the flexibility of the `adanet` framework, and\n",
        "create a custom search space for an image-classification dataset using high-level\n",
        "TensorFlow libraries like the\n",
        "[`tf.keras.layers`](https://www.tensorflow.org/guide/keras#build_advanced_models)\n",
        "functional API.\n",
        "\n"]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "xB1akik24RFa"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "# If you're running this in Colab, first install the adanet package:\n",
        "!pip install adanet"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "x_3b6xx2s6B9"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import functools\n",
        "import os\n",
        "import shutil\n",
        "\n",
        "import adanet\n",
        "from adanet.examples import simple_dnn\n",
        "import matplotlib.pyplot as plt\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "# The random seed to use.\n",
        "RANDOM_SEED = 42\n",
        "\n",
        "LOG_DIR = '/tmp/models'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7gE5Mm9j2oYw"
      },
      "source": [
        "## Fashion MNIST dataset\n",
        "\n",
        "In this example, we will use the Fashion MNIST dataset\n",
        "[[Xiao et al., 2017](https://arxiv.org/abs/1708.07747)] for classifying fashion\n",
        "apparel images into one of ten categories:\n",
        "\n",
        "1.  T-shirt/top\n",
        "2.  Trouser\n",
        "3.  Pullover\n",
        "4.  Dress\n",
        "5.  Coat\n",
        "6.  Sandal\n",
        "7.  Shirt\n",
        "8.  Sneaker\n",
        "9.  Bag\n",
        "10. Ankle boot\n",
        "\n",
        "![Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist/blob/master/doc/img/fashion-mnist-sprite.png?raw=true)"
      ]
    },
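    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The Keras loader encodes each label as an integer from 0 to 9, in the order\n",
        "listed above. Label 0 is T-shirt/top and label 9 is Ankle boot. Below is a\n",
        "minimal lookup sketch for turning integer labels back into readable names. The\n",
        "`CLASS_NAMES` list and `label_name` helper are hypothetical and are not part of\n",
        "`adanet` or Keras:\n",
        "\n",
        "```python\n",
        "# Hypothetical helper: map Fashion MNIST integer labels (0-9) to class names.\n",
        "CLASS_NAMES = [\n",
        "    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
        "    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',\n",
        "]\n",
        "\n",
        "\n",
        "def label_name(label):\n",
        "  # int() also accepts NumPy integer labels such as those in y_train.\n",
        "  return CLASS_NAMES[int(label)]\n",
        "\n",
        "\n",
        "print(label_name(0))  # T-shirt/top\n",
        "print(label_name(9))  # Ankle boot\n",
        "```"
      ]
    },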
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5_hRtdchqRZb"
      },
      "source": [
        "## Download the data\n",
        "\n",
        "Conveniently, the data is available via Keras:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 215,
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 1351,
          "status": "ok",
          "timestamp": 1545240446081,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 300
        },
        "id": "uYklOnPJ4h7g",
        "outputId": "4f27aeaa-81e9-4b4e-ddcf-1735528bf13f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz\n",
            "32768/29515 [=================================] - 0s 0us/step\n",
            "40960/29515 [=========================================] - 0s 0us/step\n",
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz\n",
            "26427392/26421880 [==============================] - 0s 0us/step\n",
            "26435584/26421880 [==============================] - 0s 0us/step\n",
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz\n",
            "16384/5148 [===============================================================================================] - 0s 0us/step\n",
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz\n",
            "4423680/4422102 [==============================] - 0s 0us/step\n",
            "4431872/4422102 [==============================] - 0s 0us/step\n"
          ]
        }
      ],
      "source": [
        "(x_train, y_train), (x_test, y_test) = (\n",
        "    tf.keras.datasets.fashion_mnist.load_data())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tECo5dFd4QCa"
      },
      "source": [
        "## Supply the data in TensorFlow\n",
        "\n",
        "Our first task is to supply the data to TensorFlow. Following the\n",
        "`tf.estimator.Estimator` convention, we will define a function that returns an\n",
        "`input_fn`, which in turn returns the feature and label `Tensors`.\n",
        "\n",
        "We will also use the `tf.data.Dataset` API to feed the data into our models."
      ]
    },
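    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Before the TensorFlow version, here is a sketch of the same factory pattern in\n",
        "plain Python, with no TensorFlow involved. An outer function captures the data\n",
        "and settings and returns a zero-argument `input_fn`. Each example is scaled to\n",
        "[0, 1], given a trailing channel axis, and wrapped in a features dictionary. The\n",
        "names `make_input_fn` and `preprocess` are hypothetical. A real `Estimator`\n",
        "input function returns batched `Tensors`, not nested Python lists.\n",
        "\n",
        "```python\n",
        "FEATURES_KEY = 'images'\n",
        "\n",
        "\n",
        "def preprocess(image, label):\n",
        "  # Scale pixels from [0, 255] to [0, 1] and add a trailing channel axis,\n",
        "  # turning a 28x28 grid into a 28x28x1 nested list.\n",
        "  scaled = [[[pixel / 255.0] for pixel in row] for row in image]\n",
        "  return {FEATURES_KEY: scaled}, label\n",
        "\n",
        "\n",
        "def make_input_fn(images, labels, batch_size):\n",
        "  # The outer function captures the data and settings; the zero-argument\n",
        "  # closure it returns is what an Estimator would call.\n",
        "  def _input_fn():\n",
        "    # For simplicity this sketch builds only the first batch.\n",
        "    pairs = [preprocess(img, lbl)\n",
        "             for img, lbl in zip(images[:batch_size], labels[:batch_size])]\n",
        "    features = {FEATURES_KEY: [f[FEATURES_KEY] for f, _ in pairs]}\n",
        "    batch_labels = [lbl for _, lbl in pairs]\n",
        "    return features, batch_labels\n",
        "  return _input_fn\n",
        "\n",
        "\n",
        "# Two toy 28x28 images: one all black, one all white.\n",
        "images = [[[0] * 28 for _ in range(28)], [[255] * 28 for _ in range(28)]]\n",
        "labels = [0, 9]\n",
        "features, batch_labels = make_input_fn(images, labels, batch_size=2)()\n",
        "print(len(features[FEATURES_KEY]))  # 2\n",
        "print(batch_labels)  # [0, 9]\n",
        "```"
      ]
    },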
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "gxTAoIXwsTH7"
      },
      "outputs": [],
      "source": [
        "FEATURES_KEY = \"images\"\n",
        "\n",
        "\n",
        "def generator(images, labels):\n",
        "  \"\"\"Returns a generator function that yields image-label pairs.\"\"\"\n",
        "\n",
        "  def _gen():\n",
        "    for image, label in zip(images, labels):\n",
        "      yield image, label\n",
        "\n",
        "  return _gen\n",
        "\n",
        "\n",
        "def preprocess_image(image, label):\n",
        "  \"\"\"Preprocesses an image for an `Estimator`.\"\"\"\n",
        "  # First let's scale the pixel values to be between 0 and 1.\n",
        "  image = image / 255.\n",
        "  # Next we reshape the image so that we can apply a 2D convolution to it.\n",
        "  image = tf.reshape(image, [28, 28, 1])\n",
        "  # Finally the features need to be supplied as a dictionary.\n",
        "  features = {FEATURES_KEY: image}\n",
        "  return features, label\n",
        "\n",
        "\n",
        "def input_fn(partition, training, batch_size):\n",
        "  \"\"\"Generate an input_fn for the Estimator.\"\"\"\n",
        "\n",
        "  def _input_fn():\n",
        "    if partition == \"train\":\n",
        "      dataset = tf.data.Dataset.from_generator(\n",
        "          generator(x_train, y_train), (tf.float32, tf.int32), ((28, 28), ()))\n",
        "    elif partition == \"predict\":\n",
        "      dataset = tf.data.Dataset.from_generator(\n",
        "          generator(x_test[:10], y_test[:10]), (tf.float32, tf.int32), ((28, 28), ()))\n",
        "    else:\n",
        "      dataset = tf.data.Dataset.from_generator(\n",
        "          generator(x_test, y_test), (tf.float32, tf.int32), ((28, 28), ()))\n",
        "\n",
        "    # We call repeat after shuffling, rather than before, to prevent separate\n",
        "    # epochs from blending together.\n",
        "    if training:\n",
        "      dataset = dataset.shuffle(10 * batch_size, seed=RANDOM_SEED).repeat()\n",
        "\n",
        "    dataset = dataset.map(preprocess_image).batch(batch_size)\n",
        "    iterator = dataset.make_one_shot_iterator()\n",
        "    features, labels = iterator.get_next()\n",
        "    return features, labels\n",
        "\n",
        "  return _input_fn"
      ]
    },
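    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The comment about calling `repeat` after `shuffle` is easiest to see in a small\n",
        "simulation. The sketch below is a rough plain-Python model of `tf.data`'s\n",
        "buffered shuffle: fill a buffer, emit a random buffered element, then refill its\n",
        "slot. It is not TensorFlow's actual implementation, and every helper in it is\n",
        "hypothetical. Each element is tagged with its epoch. With shuffle-then-repeat,\n",
        "every window of `len(data)` outputs stays inside a single epoch. With\n",
        "repeat-then-shuffle, one buffer can straddle an epoch boundary, so elements\n",
        "of the next epoch can be emitted before the current one is finished.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "\n",
        "def buffered_shuffle(stream, buffer_size, rng):\n",
        "  # Rough model of tf.data's shuffle: fill a fixed-size buffer, then emit a\n",
        "  # random buffered element and refill its slot with the next input.\n",
        "  buffer = []\n",
        "  for item in stream:\n",
        "    if len(buffer) < buffer_size:\n",
        "      buffer.append(item)\n",
        "      continue\n",
        "    i = rng.randrange(buffer_size)\n",
        "    yield buffer[i]\n",
        "    buffer[i] = item\n",
        "  # Drain whatever is left once the input is exhausted.\n",
        "  rng.shuffle(buffer)\n",
        "  for item in buffer:\n",
        "    yield item\n",
        "\n",
        "\n",
        "def shuffle_then_repeat(data, epochs, buffer_size, seed):\n",
        "  # Each epoch is shuffled on its own, so no buffer ever spans two epochs.\n",
        "  rng = random.Random(seed)\n",
        "  for epoch in range(epochs):\n",
        "    for x in buffered_shuffle([(epoch, d) for d in data], buffer_size, rng):\n",
        "      yield x\n",
        "\n",
        "\n",
        "def repeat_then_shuffle(data, epochs, buffer_size, seed):\n",
        "  # One buffer runs over the whole repeated stream, so epochs can mix.\n",
        "  rng = random.Random(seed)\n",
        "  stream = [(epoch, d) for epoch in range(epochs) for d in data]\n",
        "  for x in buffered_shuffle(stream, buffer_size, rng):\n",
        "    yield x\n",
        "\n",
        "\n",
        "def epochs_per_window(outputs, n):\n",
        "  # The epoch tags present in each consecutive window of n outputs.\n",
        "  return [sorted(set(e for e, _ in outputs[k:k + n]))\n",
        "          for k in range(0, len(outputs), n)]\n",
        "\n",
        "\n",
        "data = list(range(8))\n",
        "a = list(shuffle_then_repeat(data, epochs=3, buffer_size=4, seed=42))\n",
        "b = list(repeat_then_shuffle(data, epochs=3, buffer_size=4, seed=42))\n",
        "print(epochs_per_window(a, len(data)))  # [[0], [1], [2]]\n",
        "print(epochs_per_window(b, len(data)))  # windows may contain two epochs\n",
        "```\n",
        "\n",
        "Because the same random generator keeps advancing, shuffle-then-repeat still\n",
        "produces a different order in each epoch. It just never mixes elements across\n",
        "epochs."
      ]
    },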
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fhu2zpf9faIB"
      },
      "source": [
        "## Launch TensorBoard\n",
        "\n",
        "Let's run [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard) to visualize model training over time. We'll use [ngrok](https://ngrok.com/) to tunnel traffic to localhost.\n",
        "\n",
        "*The instructions for setting up TensorBoard were obtained from https://www.dlology.com/blog/quick-guide-to-run-tensorboard-in-google-colab/*\n",
        "\n",
        "Run the next cell and follow the printed link to open TensorBoard in a new tab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "S1UcP5yeaDz9"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "\n",
        "get_ipython().system_raw(\n",
        "    'tensorboard --logdir {} --host 0.0.0.0 --port 6006 \u0026'\n",
        "    .format(LOG_DIR)\n",
        ")\n",
        "\n",
        "# Install ngrok binary.\n",
        "! wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip\n",
        "! unzip ngrok-stable-linux-amd64.zip\n",
        "\n",
        "# Delete old logs dir.\n",
        "shutil.rmtree(LOG_DIR, ignore_errors=True)\n",
        "\n",
        "print(\"Follow this link to open TensorBoard in a new tab.\")\n",
        "get_ipython().system_raw('./ngrok http 6006 \u0026')\n",
        "! curl -s http://localhost:4040/api/tunnels | python3 -c \\\n",
        "    \"import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])\"\n",
        "\n"]
    },
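    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The last command in the cell above queries ngrok's local API for the public\n",
        "tunnel URL and extracts it with a one-line JSON parse. If you would rather stay\n",
        "in Python, the sketch below does the same lookup. Like the shell command, it\n",
        "assumes the API at `http://localhost:4040/api/tunnels` returns a `tunnels` list\n",
        "whose first entry has a `public_url` field. The helper names are hypothetical.\n",
        "The empty-list check is an addition for when ngrok has not finished starting.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "try:  # Python 3\n",
        "  from urllib.request import urlopen\n",
        "except ImportError:  # Python 2\n",
        "  from urllib2 import urlopen\n",
        "\n",
        "NGROK_API = 'http://localhost:4040/api/tunnels'\n",
        "\n",
        "\n",
        "def parse_public_url(payload):\n",
        "  # Same extraction as the shell one-liner: the first tunnel's public_url.\n",
        "  tunnels = json.loads(payload)['tunnels']\n",
        "  if not tunnels:\n",
        "    raise ValueError('ngrok reports no tunnels yet; wait and retry.')\n",
        "  return tunnels[0]['public_url']\n",
        "\n",
        "\n",
        "def fetch_public_url(api_url=NGROK_API):\n",
        "  # Only works while the ngrok process started above is running.\n",
        "  return parse_public_url(urlopen(api_url).read().decode('utf-8'))\n",
        "\n",
        "\n",
        "# print(fetch_public_url())\n",
        "```"
      ]
    },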
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vm9yudEv5lQZ"
      },
      "source": [
        "## Establish baselines\n",
        "\n",
        "Next, let's establish some baselines to see how well simple models perform on\n",
        "this dataset.\n",
        "\n",
        "Let's define some information to share with all our `tf.estimator.Estimators`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "xNwSUWh-9_Ib"
      },
      "outputs": [],
      "source": [
        "# The number of classes.\n",
        "NUM_CLASSES = 10\n",
        "\n",
        "# We will average the losses in each mini-batch when computing gradients.\n",
        "loss_reduction = tf.losses.Reduction.SUM_OVER_BATCH_SIZE\n",
        "\n",
        "# A `Head` instance defines the loss function and metrics for `Estimators`.\n",
        "head = tf.contrib.estimator.multi_class_head(\n",
        "    NUM_CLASSES, loss_reduction=loss_reduction)\n",
        "\n",
        "# Some `Estimators` use feature columns for understanding their input features.\n",
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(FEATURES_KEY, shape=[28, 28, 1])\n",
        "]\n",
        "\n",
        "def make_config(experiment_name):\n",
        "  # Estimator configuration.\n",
        "  return tf.estimator.RunConfig(\n",
        "    save_checkpoints_steps=1000,\n",
        "    save_summary_steps=1000,\n",
        "    tf_random_seed=RANDOM_SEED,\n",
        "    model_dir=os.path.join(LOG_DIR, experiment_name))"
      ]
    },
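    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "`SUM_OVER_BATCH_SIZE` divides the summed per-example loss by the number of\n",
        "examples, so the reported loss is a per-example mean. The plain-Python sketch\n",
        "below uses hypothetical helpers `softmax_cross_entropy` and\n",
        "`sum_over_batch_size` and no TensorFlow. It applies that reduction to softmax\n",
        "cross-entropy, the loss a multi-class head uses. It also shows why a model whose\n",
        "ten logits are all equal has a loss of ln(10), about 2.3026, since every class\n",
        "then gets probability 1/10. This is consistent with the first training loss\n",
        "logged by the linear baseline below (2.3025854).\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def softmax_cross_entropy(logits, label):\n",
        "  # Per-example loss -log(softmax(logits)[label]), computed stably by\n",
        "  # subtracting the largest logit before exponentiating.\n",
        "  m = max(logits)\n",
        "  log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))\n",
        "  return log_sum_exp - logits[label]\n",
        "\n",
        "\n",
        "def sum_over_batch_size(per_example_losses):\n",
        "  # Sum of the losses divided by the number of examples: a per-example mean.\n",
        "  return sum(per_example_losses) / float(len(per_example_losses))\n",
        "\n",
        "\n",
        "NUM_CLASSES = 10\n",
        "# Equal (here all-zero) logits give every class probability 1/10.\n",
        "batch_logits = [[0.0] * NUM_CLASSES for _ in range(4)]\n",
        "batch_labels = [0, 3, 7, 9]\n",
        "losses = [softmax_cross_entropy(z, y)\n",
        "          for z, y in zip(batch_logits, batch_labels)]\n",
        "print(sum_over_batch_size(losses))  # about 2.302585, i.e. ln(10)\n",
        "```"
      ]
    },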
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QY0cv-ot-Gxs"
      },
      "source": [
        "Let's start simple, and train a linear model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {
          "height": 53,
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 132686,
          "status": "ok",
          "timestamp": 1545240584407,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 300
        },
        "id": "s8wJKsi06blX",
        "outputId": "3e8e638d-7d5e-496c-8508-21fdb9d8a7a0"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpv3_YXD\n",
            "INFO:tensorflow:Using config: {'_save_checkpoints_secs': None, '_num_ps_replicas': 0, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_global_id_in_cluster': 0, '_is_chief': True, '_cluster_spec': \u003ctensorflow.python.training.server_lib.ClusterSpec object at 0x7fb7dd214b90\u003e, '_model_dir': '/tmp/tmpv3_YXD', '_protocol': None, '_save_checkpoints_steps': 50000, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_session_config': allow_soft_placement: true\n",
            "graph_options {\n",
            "  rewrite_options {\n",
            "    meta_optimizer_iterations: ONE\n",
            "  }\n",
            "}\n",
            ", '_tf_random_seed': 42, '_save_summary_steps': 50000, '_device_fn': None, '_experimental_distribute': None, '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_evaluation_master': '', '_eval_distribute': None, '_train_distribute': None, '_master': ''}\n",
            "INFO:tensorflow:Not using Distribute Coordinator.\n",
            "INFO:tensorflow:Running training and evaluation locally (non-distributed).\n",
            "INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 50000 or save_checkpoints_secs None.\n",
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Create CheckpointSaverHook.\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n",
            "INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpv3_YXD/model.ckpt.\n",
            "INFO:tensorflow:loss = 2.3025854, step = 1\n",
            "INFO:tensorflow:global_step/sec: 78.4217\n",
            "INFO:tensorflow:loss = 1.1483729, step = 101 (1.277 sec)\n",
            "INFO:tensorflow:global_step/sec: 85.4069\n",
            "INFO:tensorflow:loss = 0.5317185, step = 201 (1.171 sec)\n",
            "INFO:tensorflow:global_step/sec: 85.5113\n",
            "INFO:tensorflow:loss = 0.69216, step = 301 (1.169 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.0727\n",
            "INFO:tensorflow:loss = 0.5465001, step = 401 (1.162 sec)\n",
            "INFO:tensorflow:global_step/sec: 87.7462\n",
            "INFO:tensorflow:loss = 0.5904441, step = 501 (1.139 sec)\n",
            "INFO:tensorflow:global_step/sec: 85.2336\n",
            "INFO:tensorflow:loss = 0.50240695, step = 601 (1.173 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.591\n",
            "INFO:tensorflow:loss = 0.45161504, step = 701 (1.155 sec)\n",
            "INFO:tensorflow:global_step/sec: 84.779\n",
            "INFO:tensorflow:loss = 0.55855703, step = 801 (1.179 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.2151\n",
            "INFO:tensorflow:loss = 0.45025244, step = 901 (1.160 sec)\n",
            "INFO:tensorflow:global_step/sec: 85.855\n",
            "INFO:tensorflow:loss = 0.30628046, step = 1001 (1.165 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.2517\n",
            "INFO:tensorflow:loss = 0.36369488, step = 1101 (1.159 sec)\n",
            "INFO:tensorflow:global_step/sec: 87.135\n",
            "INFO:tensorflow:loss = 0.46128386, step = 1201 (1.148 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.0694\n",
            "INFO:tensorflow:loss = 0.37242556, step = 1301 (1.162 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.9772\n",
            "INFO:tensorflow:loss = 0.6352097, step = 1401 (1.150 sec)\n",
            "INFO:tensorflow:global_step/sec: 85.2288\n",
            "INFO:tensorflow:loss = 0.64497614, step = 1501 (1.173 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.1174\n",
            "INFO:tensorflow:loss = 0.44837192, step = 1601 (1.161 sec)\n",
            "INFO:tensorflow:global_step/sec: 85.2724\n",
            "INFO:tensorflow:loss = 0.19677357, step = 1701 (1.173 sec)\n",
            "INFO:tensorflow:global_step/sec: 85.3417\n",
            "INFO:tensorflow:loss = 0.30824977, step = 1801 (1.172 sec)\n",
            "INFO:tensorflow:global_step/sec: 82.9701\n",
            "INFO:tensorflow:loss = 0.41565034, step = 1901 (1.205 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.4188\n",
            "INFO:tensorflow:loss = 0.50474447, step = 2001 (1.157 sec)\n",
            "INFO:tensorflow:global_step/sec: 85.6055\n",
            "INFO:tensorflow:loss = 0.61880505, step = 2101 (1.168 sec)\n",
            "INFO:tensorflow:global_step/sec: 88.3828\n",
            "INFO:tensorflow:loss = 0.26540744, step = 2201 (1.131 sec)\n",
            "INFO:tensorflow:global_step/sec: 85.7678\n",
            "INFO:tensorflow:loss = 0.6249703, step = 2301 (1.170 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.285\n",
            "INFO:tensorflow:loss = 0.4627853, step = 2401 (1.155 sec)\n",
            "INFO:tensorflow:global_step/sec: 87.0864\n",
            "INFO:tensorflow:loss = 0.3652569, step = 2501 (1.149 sec)\n",
            "INFO:tensorflow:global_step/sec: 84.7971\n",
            "INFO:tensorflow:loss = 0.5498886, step = 2601 (1.179 sec)\n",
            "INFO:tensorflow:global_step/sec: 87.2721\n",
            "INFO:tensorflow:loss = 0.8027369, step = 2701 (1.146 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.4374\n",
            "INFO:tensorflow:loss = 0.35317737, step = 2801 (1.259 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.952\n",
            "INFO:tensorflow:loss = 0.40077135, step = 2901 (1.251 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.9757\n",
            "INFO:tensorflow:loss = 0.4046799, step = 3001 (1.251 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.1359\n",
            "INFO:tensorflow:loss = 0.3870783, step = 3101 (1.264 sec)\n",
            "INFO:tensorflow:global_step/sec: 83.3249\n",
            "INFO:tensorflow:loss = 0.48983032, step = 3201 (1.200 sec)\n",
            "INFO:tensorflow:global_step/sec: 85.9455\n",
            "INFO:tensorflow:loss = 0.44130433, step = 3301 (1.164 sec)\n",
            "INFO:tensorflow:global_step/sec: 87.8296\n",
            "INFO:tensorflow:loss = 0.44600248, step = 3401 (1.138 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.9828\n",
            "INFO:tensorflow:loss = 0.39337838, step = 3501 (1.150 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.7191\n",
            "INFO:tensorflow:loss = 0.39527538, step = 3601 (1.153 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.3277\n",
            "INFO:tensorflow:loss = 0.5779325, step = 3701 (1.158 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.8714\n",
            "INFO:tensorflow:loss = 0.34049273, step = 3801 (1.151 sec)\n",
            "INFO:tensorflow:global_step/sec: 88.4383\n",
            "INFO:tensorflow:loss = 0.24868618, step = 3901 (1.131 sec)\n",
            "INFO:tensorflow:global_step/sec: 87.1281\n",
            "INFO:tensorflow:loss = 0.4210478, step = 4001 (1.147 sec)\n",
            "INFO:tensorflow:global_step/sec: 87.4262\n",
            "INFO:tensorflow:loss = 0.5045028, step = 4101 (1.144 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.6544\n",
            "INFO:tensorflow:loss = 0.4103523, step = 4201 (1.154 sec)\n",
            "INFO:tensorflow:global_step/sec: 87.8312\n",
            "INFO:tensorflow:loss = 0.5953512, step = 4301 (1.139 sec)\n",
            "INFO:tensorflow:global_step/sec: 88.4659\n",
            "INFO:tensorflow:loss = 0.23190314, step = 4401 (1.131 sec)\n",
            "INFO:tensorflow:global_step/sec: 86.0797\n",
            "INFO:tensorflow:loss = 0.5709208, step = 4501 (1.161 sec)\n",
            "INFO:tensorflow:global_step/sec: 87.6714\n",
            "INFO:tensorflow:loss = 0.33438712, step = 4601 (1.141 sec)\n",
            "INFO:tensorflow:global_step/sec: 84.8791\n",
            "INFO:tensorflow:loss = 0.76975584, step = 4701 (1.178 sec)\n",
            "INFO:tensorflow:global_step/sec: 88.5043\n",
            "INFO:tensorflow:loss = 0.48181474, step = 4801 (1.130 sec)\n",
            "INFO:tensorflow:global_step/sec: 87.797\n",
            "INFO:tensorflow:loss = 0.30714703, step = 4901 (1.139 sec)\n",
            "INFO:tensorflow:Saving checkpoints for 5000 into /tmp/tmpv3_YXD/model.ckpt.\n",
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Starting evaluation at 2018-12-13-18:44:39\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Restoring parameters from /tmp/tmpv3_YXD/model.ckpt-5000\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n",
            "INFO:tensorflow:Finished evaluation at 2018-12-13-18:44:41\n",
            "INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.8413, average_loss = 0.46480885, global_step = 5000, loss = 0.46422938\n",
            "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmp/tmpv3_YXD/model.ckpt-5000\n",
            "INFO:tensorflow:Loss for final step: 0.6494813.\n",
            "Accuracy: 0.8413\n",
            "Loss: 0.46480885\n"
          ]
        }
      ],
      "source": [
        "#@test {\"skip\": true}\n",
        "#@title Parameters\n",
        "LEARNING_RATE = 0.001  #@param {type:\"number\"}\n",
        "TRAIN_STEPS = 5000  #@param {type:\"integer\"}\n",
        "BATCH_SIZE = 64  #@param {type:\"integer\"}\n",
        "\n",
        "estimator = tf.estimator.LinearClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    n_classes=NUM_CLASSES,\n",
        "    optimizer=tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE),\n",
        "    loss_reduction=loss_reduction,\n",
        "    config=make_config(\"linear\"))\n",
        "\n",
        "results, _ = tf.estimator.train_and_evaluate(\n",
        "    estimator,\n",
        "    train_spec=tf.estimator.TrainSpec(\n",
        "        input_fn=input_fn(\"train\", training=True, batch_size=BATCH_SIZE),\n",
        "        max_steps=TRAIN_STEPS),\n",
        "    eval_spec=tf.estimator.EvalSpec(\n",
        "        input_fn=input_fn(\"test\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None,\n",
        "        start_delay_secs=1,\n",
        "        throttle_secs=1,\n",
        "    ))\n",
        "print(\"Accuracy:\", results[\"accuracy\"])\n",
        "print(\"Loss:\", results[\"average_loss\"])"
      ]
    },
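    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The baseline above is a `LinearClassifier`. It flattens each 28x28x1 image from\n",
        "the `numeric_column` into 784 features and computes one logit per class as a\n",
        "weighted sum of those features plus a bias. Counting its trainable parameters\n",
        "is a quick sanity check on model size. The sketch below is an illustrative\n",
        "plain-Python calculation with hypothetical helpers. It assumes one weight per\n",
        "feature per class and one bias per class, so 784 x 10 + 10 = 7,850 parameters.\n",
        "\n",
        "```python\n",
        "def linear_param_count(height, width, channels, num_classes):\n",
        "  # One weight per (input feature, class) pair plus one bias per class.\n",
        "  num_features = height * width * channels\n",
        "  return num_features * num_classes + num_classes\n",
        "\n",
        "\n",
        "def linear_logits(features, weights, biases):\n",
        "  # logits[c] = sum_i features[i] * weights[i][c] + biases[c]\n",
        "  return [sum(x * w[c] for x, w in zip(features, weights)) + b\n",
        "          for c, b in enumerate(biases)]\n",
        "\n",
        "\n",
        "print(linear_param_count(28, 28, 1, 10))  # 7850\n",
        "\n",
        "# Toy check with 2 features and 2 classes.\n",
        "print(linear_logits([1.0, 2.0], [[1.0, 0.0], [0.0, 1.0]], [0.5, -0.5]))  # [1.5, 1.5]\n",
        "```"
      ]
    },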
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "a-1hE03c7_Yj"
      },
      "source": [
        "The linear model with default parameters achieves about **84.13% accuracy**.\n",
        "\n",
        "Let's see if we can do better with the `simple_dnn` AdaNet:"
      ]
    },
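    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Before running the next cell, it helps to know what its log reports. For each\n",
        "candidate, AdaNet prints metrics for the subnetwork alone (`subnetwork`) and for\n",
        "two ensembles, `adanet_weighted_ensemble` and `uniform_average_ensemble`. Both\n",
        "ensembles combine the logits of their subnetworks. The uniform average gives\n",
        "each subnetwork the same weight, while the weighted ensemble uses mixture\n",
        "weights that AdaNet learns. The plain-Python sketch below illustrates only these\n",
        "two combination rules, using made-up logits and hand-picked scalar weights. It is\n",
        "not AdaNet's implementation, which also supports vector or matrix mixture\n",
        "weights and an optional bias.\n",
        "\n",
        "```python\n",
        "def uniform_average(subnetwork_logits):\n",
        "  # Every subnetwork contributes equally: weight 1/N each.\n",
        "  n = float(len(subnetwork_logits))\n",
        "  return [sum(col) / n for col in zip(*subnetwork_logits)]\n",
        "\n",
        "\n",
        "def weighted_ensemble(subnetwork_logits, mixture_weights):\n",
        "  # One scalar weight per subnetwork; these stand in for learned weights.\n",
        "  return [sum(w * z for w, z in zip(mixture_weights, col))\n",
        "          for col in zip(*subnetwork_logits)]\n",
        "\n",
        "\n",
        "# Hypothetical logits from two subnetworks for one 3-class example.\n",
        "linear_out = [2.0, 0.0, 1.0]\n",
        "dnn_out = [4.0, 1.0, 0.0]\n",
        "print(uniform_average([linear_out, dnn_out]))  # [3.0, 0.5, 0.5]\n",
        "print(weighted_ensemble([linear_out, dnn_out], [0.25, 0.75]))  # [3.5, 0.75, 0.25]\n",
        "```"
      ]
    },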
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 2154,
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 111638,
          "status": "error",
          "timestamp": 1545240696091,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 300
        },
        "id": "9fAoRYd19eUs",
        "outputId": "e842798b-ca2b-46a8-aa23-41c96a92d929"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpBpWYGG\n",
            "INFO:tensorflow:Using config: {'_save_checkpoints_secs': None, '_num_ps_replicas': 0, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_global_id_in_cluster': 0, '_is_chief': True, '_cluster_spec': \u003ctensorflow.python.training.server_lib.ClusterSpec object at 0x7fb7da305b50\u003e, '_model_dir': '/tmp/tmpBpWYGG', '_protocol': None, '_save_checkpoints_steps': 50000, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_session_config': allow_soft_placement: true\n",
            "graph_options {\n",
            "  rewrite_options {\n",
            "    meta_optimizer_iterations: ONE\n",
            "  }\n",
            "}\n",
            ", '_tf_random_seed': 42, '_save_summary_steps': 50000, '_device_fn': None, '_experimental_distribute': None, '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_evaluation_master': '', '_eval_distribute': None, '_train_distribute': None, '_master': ''}\n",
            "INFO:tensorflow:Not using Distribute Coordinator.\n",
            "INFO:tensorflow:Running training and evaluation locally (non-distributed).\n",
            "INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 50000 or save_checkpoints_secs None.\n",
            "INFO:tensorflow:Beginning training AdaNet iteration 0\n",
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Building iteration 0\n",
            "INFO:tensorflow:Building subnetwork 'linear'\n",
            "WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/adanet/examples/simple_dnn.py:104: calling __new__ (from adanet.core.subnetwork.generator) with persisted_tensors is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "`persisted_tensors` is deprecated, please use `shared` instead.\n",
            "INFO:tensorflow:Building subnetwork '1_layer_dnn'\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Create CheckpointSaverHook.\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n",
            "INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpBpWYGG/model.ckpt.\n",
            "INFO:tensorflow:loss = 2.5908535, step = 1\n",
            "INFO:tensorflow:global_step/sec: 62.6364\n",
            "INFO:tensorflow:loss = 0.99051607, step = 101 (1.598 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.613\n",
            "INFO:tensorflow:loss = 0.42087823, step = 201 (1.288 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.5399\n",
            "INFO:tensorflow:loss = 0.54013, step = 301 (1.289 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.6072\n",
            "INFO:tensorflow:loss = 0.4134988, step = 401 (1.272 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.7752\n",
            "INFO:tensorflow:loss = 0.5669622, step = 501 (1.269 sec)\n",
            "INFO:tensorflow:global_step/sec: 80.2865\n",
            "INFO:tensorflow:loss = 0.39677918, step = 601 (1.245 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.4194\n",
            "INFO:tensorflow:loss = 0.35278338, step = 701 (1.259 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.3689\n",
            "INFO:tensorflow:loss = 0.5044695, step = 801 (1.260 sec)\n",
            "INFO:tensorflow:global_step/sec: 80.2831\n",
            "INFO:tensorflow:loss = 0.38145697, step = 901 (1.246 sec)\n",
            "INFO:tensorflow:global_step/sec: 76.9801\n",
            "INFO:tensorflow:loss = 0.20107026, step = 1001 (1.299 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.9748\n",
            "INFO:tensorflow:loss = 0.2736662, step = 1101 (1.266 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.4406\n",
            "INFO:tensorflow:loss = 0.42778087, step = 1201 (1.259 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.3415\n",
            "INFO:tensorflow:loss = 0.38598654, step = 1301 (1.260 sec)\n",
            "INFO:tensorflow:global_step/sec: 75.5153\n",
            "INFO:tensorflow:loss = 0.5578107, step = 1401 (1.324 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.7477\n",
            "INFO:tensorflow:loss = 0.47856787, step = 1501 (1.270 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.5178\n",
            "INFO:tensorflow:loss = 0.32456928, step = 1601 (1.274 sec)\n",
            "INFO:tensorflow:global_step/sec: 80.1592\n",
            "INFO:tensorflow:loss = 0.14113683, step = 1701 (1.248 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.9168\n",
            "INFO:tensorflow:loss = 0.21314168, step = 1801 (1.251 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.9573\n",
            "INFO:tensorflow:loss = 0.32758698, step = 1901 (1.283 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.9252\n",
            "INFO:tensorflow:loss = 0.33568507, step = 2001 (1.251 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.5538\n",
            "INFO:tensorflow:loss = 0.5402397, step = 2101 (1.273 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.6781\n",
            "INFO:tensorflow:loss = 0.18559779, step = 2201 (1.255 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.7851\n",
            "INFO:tensorflow:loss = 0.634346, step = 2301 (1.253 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.0365\n",
            "INFO:tensorflow:loss = 0.3705944, step = 2401 (1.265 sec)\n",
            "INFO:tensorflow:Saving checkpoints for 2500 into /tmp/tmpBpWYGG/model.ckpt.\n",
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Building iteration 0\n",
            "INFO:tensorflow:Building subnetwork 'linear'\n",
            "INFO:tensorflow:Building subnetwork '1_layer_dnn'\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Starting evaluation at 2018-12-13-18:45:18\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Restoring parameters from /tmp/tmpBpWYGG/model.ckpt-2500\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n",
            "INFO:tensorflow:Saving candidate 't0_linear' dict for global step 2500: accuracy/adanet/adanet_weighted_ensemble = 0.8415, accuracy/adanet/subnetwork = 0.8415, accuracy/adanet/uniform_average_ensemble = 0.8415, architecture/adanet/ensembles = \n",
            "W\n",
            "9adanet/iteration_0/ensemble_t0_linear/architecture/adanetB\u0010\u0008\u0007\u0012\u0000B\n",
            "| linear |J\u0008\n",
            "\u0006\n",
            "\u0004text, average_loss/adanet/adanet_weighted_ensemble = 0.4806687, average_loss/adanet/subnetwork = 0.4806687, average_loss/adanet/uniform_average_ensemble = 0.4806687, loss/adanet/adanet_weighted_ensemble = 0.47982842, loss/adanet/subnetwork = 0.47982842, loss/adanet/uniform_average_ensemble = 0.47982842\n",
            "INFO:tensorflow:Saving candidate 't0_1_layer_dnn' dict for global step 2500: accuracy/adanet/adanet_weighted_ensemble = 0.8566, accuracy/adanet/subnetwork = 0.8566, accuracy/adanet/uniform_average_ensemble = 0.8566, architecture/adanet/ensembles = \n",
            "a\n",
            "\u003eadanet/iteration_0/ensemble_t0_1_layer_dnn/architecture/adanetB\u0015\u0008\u0007\u0012\u0000B\u000f| 1_layer_dnn |J\u0008\n",
            "\u0006\n",
            "\u0004text, average_loss/adanet/adanet_weighted_ensemble = 0.40864596, average_loss/adanet/subnetwork = 0.40864596, average_loss/adanet/uniform_average_ensemble = 0.40864596, loss/adanet/adanet_weighted_ensemble = 0.40789038, loss/adanet/subnetwork = 0.40789038, loss/adanet/uniform_average_ensemble = 0.40789038\n",
            "INFO:tensorflow:Finished evaluation at 2018-12-13-18:45:22\n",
            "INFO:tensorflow:Saving dict for global step 2500: accuracy = 0.8566, accuracy/adanet/adanet_weighted_ensemble = 0.8566, accuracy/adanet/subnetwork = 0.8566, accuracy/adanet/uniform_average_ensemble = 0.8566, average_loss = 0.40864596, average_loss/adanet/adanet_weighted_ensemble = 0.40864596, average_loss/adanet/subnetwork = 0.40864596, average_loss/adanet/uniform_average_ensemble = 0.40864596, global_step = 2500, loss = 0.40789038, loss/adanet/adanet_weighted_ensemble = 0.40789038, loss/adanet/subnetwork = 0.40789038, loss/adanet/uniform_average_ensemble = 0.40789038\n",
            "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2500: /tmp/tmpBpWYGG/model.ckpt-2500\n",
            "INFO:tensorflow:Loss for final step: 0.3498758.\n",
            "INFO:tensorflow:Finished training Adanet iteration 0\n",
            "INFO:tensorflow:Beginning bookkeeping phase for iteration 0\n",
            "INFO:tensorflow:Building iteration 0\n",
            "INFO:tensorflow:Building subnetwork 'linear'\n",
            "INFO:tensorflow:Building subnetwork '1_layer_dnn'\n",
            "INFO:tensorflow:Starting ensemble evaluation for iteration 0\n",
            "INFO:tensorflow:Restoring parameters from /tmp/tmpBpWYGG/model.ckpt-2500\n",
            "WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/adanet/core/estimator.py:717: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "To construct input pipelines, use the `tf.data` module.\n",
            "WARNING:tensorflow:`tf.train.start_queue_runners()` was called when no queue runners were defined. You can safely remove the call to this deprecated function.\n",
            "INFO:tensorflow:Encountered end of input after 939 evaluations\n",
            "INFO:tensorflow:Computed ensemble metrics: adanet_loss/t0_linear = 0.425563, adanet_loss/t0_1_layer_dnn = 0.343846\n",
            "INFO:tensorflow:Finished ensemble evaluation for iteration 0\n",
            "INFO:tensorflow:'t0_1_layer_dnn' at index 1 is moving onto the next iteration\n",
            "INFO:tensorflow:Importing architecture from /tmp/tmpBpWYGG/architecture-0.txt: ['0:1_layer_dnn'].\n",
            "INFO:tensorflow:Rebuilding iteration 0\n",
            "INFO:tensorflow:Building subnetwork '1_layer_dnn'\n",
            "INFO:tensorflow:Warm-starting from: (u'/tmp/tmpBpWYGG/model.ckpt-2500',)\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_1_layer_dnn/weighted_subnetwork_0/logits/mixture_weight; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: global_step; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/candidate_t0_1_layer_dnn/adanet/iteration_0/candidate_t0_1_layer_dnn/adanet_loss/local_step; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_1_layer_dnn/bias; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/step; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_1_layer_dnn/weighted_subnetwork_0/subnetwork/dense/bias; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/candidate_t0_1_layer_dnn/adanet_loss; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/train_op/is_over/is_over_var_fn/is_over_var; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_1_layer_dnn/weighted_subnetwork_0/subnetwork/dense/kernel; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/candidate_t0_1_layer_dnn/adanet/iteration_0/candidate_t0_1_layer_dnn/adanet_loss/biased; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_1_layer_dnn/weighted_subnetwork_0/subnetwork/dense_1/kernel; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_1_layer_dnn/weighted_subnetwork_0/subnetwork/dense_1/bias; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Building iteration 1\n",
            "INFO:tensorflow:Building subnetwork '1_layer_dnn'\n",
            "INFO:tensorflow:Building subnetwork '2_layer_dnn'\n",
            "INFO:tensorflow:Overwriting checkpoint with new graph for iteration 1 to /tmp/tmpBpWYGG/model.ckpt-2500\n",
            "WARNING:tensorflow:`tf.train.start_queue_runners()` was called when no queue runners were defined. You can safely remove the call to this deprecated function.\n",
            "INFO:tensorflow:Finished bookkeeping phase for iteration 0\n",
            "INFO:tensorflow:Beginning training AdaNet iteration 1\n",
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Importing architecture from /tmp/tmpBpWYGG/architecture-0.txt: ['0:1_layer_dnn'].\n",
            "INFO:tensorflow:Rebuilding iteration 0\n",
            "INFO:tensorflow:Building subnetwork '1_layer_dnn'\n",
            "INFO:tensorflow:Building iteration 1\n",
            "INFO:tensorflow:Building subnetwork '1_layer_dnn'\n",
            "INFO:tensorflow:Building subnetwork '2_layer_dnn'\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Create CheckpointSaverHook.\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Restoring parameters from /tmp/tmpBpWYGG/increment.ckpt-1\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n",
            "INFO:tensorflow:Saving checkpoints for 2500 into /tmp/tmpBpWYGG/model.ckpt.\n",
            "INFO:tensorflow:loss = 0.24298608, step = 2501\n",
            "INFO:tensorflow:global_step/sec: 56.6617\n",
            "INFO:tensorflow:loss = 0.33409065, step = 2601 (1.766 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.124\n",
            "INFO:tensorflow:loss = 0.20756069, step = 2701 (1.280 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.2965\n",
            "INFO:tensorflow:loss = 0.3423041, step = 2801 (1.294 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.225\n",
            "INFO:tensorflow:loss = 0.38019755, step = 2901 (1.262 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.2797\n",
            "INFO:tensorflow:loss = 0.31078112, step = 3001 (1.278 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.6699\n",
            "INFO:tensorflow:loss = 0.2353632, step = 3101 (1.288 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.9738\n",
            "INFO:tensorflow:loss = 0.25860834, step = 3201 (1.266 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.3627\n",
            "INFO:tensorflow:loss = 0.37581128, step = 3301 (1.260 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.7281\n",
            "INFO:tensorflow:loss = 0.273139, step = 3401 (1.286 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.6977\n",
            "INFO:tensorflow:loss = 0.1418717, step = 3501 (1.288 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.2933\n",
            "INFO:tensorflow:loss = 0.2681304, step = 3601 (1.277 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.4242\n",
            "INFO:tensorflow:loss = 0.3804762, step = 3701 (1.292 sec)\n",
            "INFO:tensorflow:global_step/sec: 76.9244\n",
            "INFO:tensorflow:loss = 0.32000834, step = 3801 (1.300 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.8382\n",
            "INFO:tensorflow:loss = 0.36946106, step = 3901 (1.268 sec)\n",
            "INFO:tensorflow:global_step/sec: 79.5428\n",
            "INFO:tensorflow:loss = 0.38832623, step = 4001 (1.257 sec)\n",
            "INFO:tensorflow:global_step/sec: 71.8387\n",
            "INFO:tensorflow:loss = 0.34555018, step = 4101 (1.392 sec)\n",
            "INFO:tensorflow:global_step/sec: 76.6725\n",
            "INFO:tensorflow:loss = 0.21983796, step = 4201 (1.304 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.3602\n",
            "INFO:tensorflow:loss = 0.17176698, step = 4301 (1.276 sec)\n",
            "INFO:tensorflow:global_step/sec: 76.9497\n",
            "INFO:tensorflow:loss = 0.34545347, step = 4401 (1.300 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.9193\n",
            "INFO:tensorflow:loss = 0.339509, step = 4501 (1.284 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.6756\n",
            "INFO:tensorflow:loss = 0.5344475, step = 4601 (1.287 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.3863\n",
            "INFO:tensorflow:loss = 0.17075369, step = 4701 (1.292 sec)\n",
            "INFO:tensorflow:global_step/sec: 78.0709\n",
            "INFO:tensorflow:loss = 0.53067696, step = 4801 (1.281 sec)\n",
            "INFO:tensorflow:global_step/sec: 77.9967\n",
            "INFO:tensorflow:loss = 0.29755312, step = 4901 (1.282 sec)\n",
            "INFO:tensorflow:Saving checkpoints for 5000 into /tmp/tmpBpWYGG/model.ckpt.\n",
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Importing architecture from /tmp/tmpBpWYGG/architecture-0.txt: ['0:1_layer_dnn'].\n",
            "INFO:tensorflow:Rebuilding iteration 0\n",
            "INFO:tensorflow:Building subnetwork '1_layer_dnn'\n",
            "INFO:tensorflow:Building iteration 1\n",
            "INFO:tensorflow:Building subnetwork '1_layer_dnn'\n",
            "INFO:tensorflow:Building subnetwork '2_layer_dnn'\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Starting evaluation at 2018-12-13-18:46:18\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Restoring parameters from /tmp/tmpBpWYGG/model.ckpt-5000\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n",
            "INFO:tensorflow:Saving candidate 't0_1_layer_dnn' dict for global step 5000: accuracy/adanet/adanet_weighted_ensemble = 0.8566, accuracy/adanet/subnetwork = 0.8566, accuracy/adanet/uniform_average_ensemble = 0.8566, architecture/adanet/ensembles = \n",
            "a\n",
            "\u003eadanet/iteration_0/ensemble_t0_1_layer_dnn/architecture/adanetB\u0015\u0008\u0007\u0012\u0000B\u000f| 1_layer_dnn |J\u0008\n",
            "\u0006\n",
            "\u0004text, average_loss/adanet/adanet_weighted_ensemble = 0.40864596, average_loss/adanet/subnetwork = 0.40864596, average_loss/adanet/uniform_average_ensemble = 0.40864596, loss/adanet/adanet_weighted_ensemble = 0.40789038, loss/adanet/subnetwork = 0.40789038, loss/adanet/uniform_average_ensemble = 0.40789038\n",
            "INFO:tensorflow:Saving candidate 't1_1_layer_dnn' dict for global step 5000: accuracy/adanet/adanet_weighted_ensemble = 0.8566, accuracy/adanet/subnetwork = 0.8566, accuracy/adanet/uniform_average_ensemble = 0.8566, architecture/adanet/ensembles = \n",
            "o\n",
            "\u003eadanet/iteration_1/ensemble_t1_1_layer_dnn/architecture/adanetB#\u0008\u0007\u0012\u0000B\u001d| 1_layer_dnn | 1_layer_dnn |J\u0008\n",
            "\u0006\n",
            "\u0004text, average_loss/adanet/adanet_weighted_ensemble = 0.40864596, average_loss/adanet/subnetwork = 0.40864596, average_loss/adanet/uniform_average_ensemble = 0.40864596, loss/adanet/adanet_weighted_ensemble = 0.40789038, loss/adanet/subnetwork = 0.40789038, loss/adanet/uniform_average_ensemble = 0.40789038\n",
            "INFO:tensorflow:Saving candidate 't1_2_layer_dnn' dict for global step 5000: accuracy/adanet/adanet_weighted_ensemble = 0.8616, accuracy/adanet/subnetwork = 0.8453, accuracy/adanet/uniform_average_ensemble = 0.8616, architecture/adanet/ensembles = \n",
            "o\n",
            "\u003eadanet/iteration_1/ensemble_t1_2_layer_dnn/architecture/adanetB#\u0008\u0007\u0012\u0000B\u001d| 1_layer_dnn | 2_layer_dnn |J\u0008\n",
            "\u0006\n",
            "\u0004text, average_loss/adanet/adanet_weighted_ensemble = 0.39428768, average_loss/adanet/subnetwork = 0.4410424, average_loss/adanet/uniform_average_ensemble = 0.39428768, loss/adanet/adanet_weighted_ensemble = 0.3936451, loss/adanet/subnetwork = 0.44045082, loss/adanet/uniform_average_ensemble = 0.3936451\n",
            "INFO:tensorflow:Finished evaluation at 2018-12-13-18:46:22\n",
            "INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.8566, accuracy/adanet/adanet_weighted_ensemble = 0.8566, accuracy/adanet/subnetwork = 0.8566, accuracy/adanet/uniform_average_ensemble = 0.8566, average_loss = 0.40864596, average_loss/adanet/adanet_weighted_ensemble = 0.40864596, average_loss/adanet/subnetwork = 0.40864596, average_loss/adanet/uniform_average_ensemble = 0.40864596, global_step = 5000, loss = 0.40789038, loss/adanet/adanet_weighted_ensemble = 0.40789038, loss/adanet/subnetwork = 0.40789038, loss/adanet/uniform_average_ensemble = 0.40789038\n",
            "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmp/tmpBpWYGG/model.ckpt-5000\n",
            "INFO:tensorflow:Loss for final step: 0.21913758.\n",
            "INFO:tensorflow:Finished training Adanet iteration 1\n",
            "Accuracy: 0.8566\n",
            "Loss: 0.40864596\n"
          ]
        }
      ],
      "source": [
        "#@test {\"skip\": true}\n",
        "#@title Parameters\n",
        "LEARNING_RATE = 0.003  #@param {type:\"number\"}\n",
        "TRAIN_STEPS = 5000  #@param {type:\"integer\"}\n",
        "BATCH_SIZE = 64  #@param {type:\"integer\"}\n",
        "ADANET_ITERATIONS = 2  #@param {type:\"integer\"}\n",
        "\n",
        "estimator = adanet.Estimator(\n",
        "    head=head,\n",
        "    subnetwork_generator=simple_dnn.Generator(\n",
        "        feature_columns=feature_columns,\n",
        "        optimizer=tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE),\n",
        "        seed=RANDOM_SEED),\n",
        "    max_iteration_steps=TRAIN_STEPS // ADANET_ITERATIONS,\n",
        "    evaluator=adanet.Evaluator(\n",
        "        input_fn=input_fn(\"train\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None),\n",
        "    config=make_config(\"simple_dnn\"))\n",
        "\n",
        "results, _ = tf.estimator.train_and_evaluate(\n",
        "    estimator,\n",
        "    train_spec=tf.estimator.TrainSpec(\n",
        "        input_fn=input_fn(\"train\", training=True, batch_size=BATCH_SIZE),\n",
        "        max_steps=TRAIN_STEPS),\n",
        "    eval_spec=tf.estimator.EvalSpec(\n",
        "        input_fn=input_fn(\"test\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None,\n",
        "        start_delay_secs=1,\n",
        "        throttle_secs=1,\n",
        "    ))\n",
        "print(\"Accuracy:\", results[\"accuracy\"])\n",
        "print(\"Loss:\", results[\"average_loss\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ysWsJ3zXDwNx"
      },
      "source": [
        "The `simple_dnn` AdaNet model with default parameters achieves **85.66%\n",
        "accuracy** on the test set.\n",
        "\n",
        "This improvement can be attributed to `simple_dnn` searching over\n",
        "fully-connected neural networks, which have more expressive power than the\n",
        "linear model because of their non-linear activations.\n",
        "\n",
        "Fully-connected layers are invariant to permutations of their inputs: if we\n",
        "consistently swapped two pixels in every image before training, the model could\n",
        "learn an equivalent function and perform just as well. However, images contain\n",
        "spatial and locality information that we should try to capture. Applying a few\n",
        "convolutions to our inputs lets us do so, and that requires defining a custom\n",
        "`adanet.subnetwork.Builder` and `adanet.subnetwork.Generator`."
      ]
    },
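    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The permutation-invariance claim above can be checked numerically. The sketch\n",
        "below is a minimal illustration that assumes NumPy and a flattened 28x28 input\n",
        "(784 pixels); the layer width and random values are arbitrary. Reordering the\n",
        "pixels, and applying the same reordering to the columns of a dense layer's\n",
        "kernel, produces the same activations (up to floating-point rounding). A\n",
        "fully-connected layer can therefore represent the same function for any fixed\n",
        "pixel order:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.RandomState(0)\n",
        "num_pixels, units = 784, 64  # assumed: a flattened 28x28 image, 64 hidden units\n",
        "\n",
        "x = rng.normal(size=num_pixels)           # one flattened input image\n",
        "W = rng.normal(size=(units, num_pixels))  # dense-layer kernel\n",
        "b = rng.normal(size=units)                # dense-layer bias\n",
        "\n",
        "perm = rng.permutation(num_pixels)  # a fixed, consistent reordering of pixels\n",
        "x_perm = x[perm]                    # the same image with its pixels shuffled\n",
        "W_perm = W[:, perm]                 # kernel columns reordered to match\n",
        "\n",
        "y = np.maximum(W.dot(x) + b, 0.0)                 # ReLU(Wx + b)\n",
        "y_perm = np.maximum(W_perm.dot(x_perm) + b, 0.0)  # ReLU(W'x' + b)\n",
        "print(np.allclose(y, y_perm))\n",
        "```\n",
        "\n",
        "This prints `True`. A convolution, by contrast, combines each pixel only with\n",
        "its spatial neighbours, so a consistent pixel shuffle destroys the local\n",
        "structure it relies on."
      ]
    },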
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "D3IE6-9vFVlg"
      },
      "source": [
        "## Define a convolutional AdaNet model\n",
        "\n",
        "Creating a new search space for AdaNet to explore is straightforward. There are\n",
        "two abstract classes you need to extend:\n",
        "\n",
        "1.  `adanet.subnetwork.Builder`\n",
        "2.  `adanet.subnetwork.Generator`\n",
        "\n",
        "Similar to the `model_fn` of a `tf.estimator.Estimator`,\n",
        "`adanet.subnetwork.Builder` lets you define your own TensorFlow graph for\n",
        "creating a neural network and specify its training operations.\n",
        "\n",
        "Below we define one that applies a 2D convolution and max-pooling to the\n",
        "images, followed by a fully-connected hidden layer and a logits layer:"
      ]
    },
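    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough sanity check of the architecture described above, the shapes and\n",
        "parameter counts can be worked out by hand. The sketch below assumes 28x28\n",
        "single-channel images and omits the batch dimension. It mirrors the layer\n",
        "settings used in `SimpleCNNBuilder`: a 3x3 convolution with 16 filters and\n",
        "`same` padding, 2x2 max-pooling with stride 2 and the Keras default `valid`\n",
        "padding, a 64-unit dense layer, and 10 logits.\n",
        "\n",
        "```python\n",
        "def pool_out(size, pool_size=2, strides=2):\n",
        "    # Output length of a 'valid'-padded pooling window along one axis.\n",
        "    return (size - pool_size) // strides + 1\n",
        "\n",
        "height = width = 28  # assumed input size\n",
        "channels = 1         # assumed grayscale input\n",
        "filters, kernel_size = 16, 3\n",
        "hidden_units, num_classes = 64, 10\n",
        "\n",
        "# Conv2D with 'same' padding and stride 1 keeps height and width unchanged.\n",
        "conv_params = kernel_size * kernel_size * channels * filters + filters\n",
        "pooled_h, pooled_w = pool_out(height), pool_out(width)\n",
        "flat = pooled_h * pooled_w * filters  # vector length after Flatten\n",
        "dense_params = flat * hidden_units + hidden_units\n",
        "logit_params = hidden_units * num_classes + num_classes\n",
        "total_params = conv_params + dense_params + logit_params\n",
        "print(pooled_h, pooled_w, flat, total_params)\n",
        "```\n",
        "\n",
        "Under these assumptions, almost all of the roughly 200K parameters sit in the\n",
        "first dense layer, because the flattened 14x14x16 feature map feeds it directly."
      ]
    },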
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "IsYJ97tRwBkt"
      },
      "outputs": [],
      "source": [
        "class SimpleCNNBuilder(adanet.subnetwork.Builder):\n",
        "  \"\"\"Builds a CNN subnetwork for AdaNet.\"\"\"\n",
        "\n",
        "  def __init__(self, learning_rate, max_iteration_steps, seed):\n",
        "    \"\"\"Initializes a `SimpleCNNBuilder`.\n",
        "\n",
        "    Args:\n",
        "      learning_rate: The float learning rate to use.\n",
        "      max_iteration_steps: The number of steps per iteration.\n",
        "      seed: The random seed.\n",
        "\n",
        "    Returns:\n",
        "      An instance of `SimpleCNNBuilder`.\n",
        "    \"\"\"\n",
        "    self._learning_rate = learning_rate\n",
        "    self._max_iteration_steps = max_iteration_steps\n",
        "    self._seed = seed\n",
        "\n",
        "  def build_subnetwork(self,\n",
        "                       features,\n",
        "                       logits_dimension,\n",
        "                       training,\n",
        "                       iteration_step,\n",
        "                       summary,\n",
        "                       previous_ensemble=None):\n",
        "    \"\"\"See `adanet.subnetwork.Builder`.\"\"\"\n",
        "    images = list(features.values())[0]\n",
        "    \n",
        "    # Visualize some of the input images in TensorBoard.\n",
        "    summary.image(\"images\", images)\n",
        "    \n",
        "    kernel_initializer = tf.keras.initializers.he_normal(seed=self._seed)\n",
        "    x = tf.keras.layers.Conv2D(\n",
        "        filters=16,\n",
        "        kernel_size=3,\n",
        "        padding=\"same\",\n",
        "        activation=\"relu\",\n",
        "        kernel_initializer=kernel_initializer)(\n",
        "            images)\n",
        "    x = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)(x)\n",
        "    x = tf.keras.layers.Flatten()(x)\n",
        "    x = tf.keras.layers.Dense(\n",
        "        units=64, activation=\"relu\", kernel_initializer=kernel_initializer)(\n",
        "            x)\n",
        "\n",
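        "    # The `Head` passed to adanet.Estimator will apply the softmax activation.\n",
        "    # Use `logits_dimension` so the output size matches what the head expects.\n",
        "    logits = tf.keras.layers.Dense(\n",
        "        units=logits_dimension,\n",
        "        activation=None,\n",
        "        kernel_initializer=kernel_initializer)(\n",
        "            x)\n",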
        "\n",
        "    # Use a constant complexity measure, since all subnetworks have the same\n",
        "    # architecture and hyperparameters.\n",
        "    complexity = tf.constant(1)\n",
        "\n",
        "    return adanet.Subnetwork(\n",
        "        last_layer=x,\n",
        "        logits=logits,\n",
        "        complexity=complexity,\n",
        "        persisted_tensors={})\n",
        "\n",
        "  def build_subnetwork_train_op(self,\n",
        "                                subnetwork,\n",
        "                                loss,\n",
        "                                var_list,\n",
        "                                labels,\n",
        "                                iteration_step,\n",
        "                                summary,\n",
        "                                previous_ensemble=None):\n",
        "    \"\"\"See `adanet.subnetwork.Builder`.\"\"\"\n",
        "\n",
        "    # Momentum optimizer with cosine learning rate decay works well with CNNs.\n",
        "    learning_rate = tf.train.cosine_decay(\n",
        "        learning_rate=self._learning_rate,\n",
        "        global_step=iteration_step,\n",
        "        decay_steps=self._max_iteration_steps)\n",
        "    optimizer = tf.train.MomentumOptimizer(learning_rate, .9)\n",
        "    # NOTE: The `adanet.Estimator` increments the global step.\n",
        "    return optimizer.minimize(loss=loss, var_list=var_list)\n",
        "\n",
        "  def build_mixture_weights_train_op(self, loss, var_list, logits, labels,\n",
        "                                     iteration_step, summary):\n",
        "    \"\"\"See `adanet.subnetwork.Builder`.\"\"\"\n",
        "    return tf.no_op(\"mixture_weights_train_op\")\n",
        "\n",
        "  @property\n",
        "  def name(self):\n",
        "    \"\"\"See `adanet.subnetwork.Builder`.\"\"\"\n",
        "    return \"simple_cnn\""
      ]
    },
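    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "`build_subnetwork_train_op` relies on `tf.train.cosine_decay`. The plain-Python\n",
        "sketch below reproduces the formula documented for that function, with its\n",
        "default `alpha=0.0`, so you can inspect the schedule without TensorFlow. The\n",
        "base rate of 0.05 and the 2500 steps per iteration are illustrative values only.\n",
        "Because the builder passes `iteration_step` rather than the global step, the\n",
        "schedule restarts at the beginning of each AdaNet iteration.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def cosine_decay(learning_rate, step, decay_steps, alpha=0.0):\n",
        "    # Clamp the step so the rate stays at its floor after decay_steps.\n",
        "    step = min(step, decay_steps)\n",
        "    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))\n",
        "    return learning_rate * ((1.0 - alpha) * cosine + alpha)\n",
        "\n",
        "# Illustrative values: base rate 0.05 and 2500 steps per AdaNet iteration.\n",
        "for step in (0, 625, 1250, 1875, 2500):\n",
        "    print(step, round(cosine_decay(0.05, step, 2500), 5))\n",
        "```\n",
        "\n",
        "The rate starts at the base value, falls to half of it midway through the\n",
        "iteration, and reaches zero at `decay_steps`."
      ]
    },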
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OFamPrZHJ5ii"
      },
      "source": [
        "Next, we extend an `adanet.subnetwork.Generator`, which defines the search\n",
        "space of candidate `SimpleCNNBuilder`s to consider including in the final\n",
        "network. It can create one or more candidates at each iteration with different\n",
        "parameters, and the AdaNet algorithm will select the candidate that best\n",
        "improves the overall neural network's `adanet_loss` on the training set.\n",
        "\n",
        "The one below is very simple: it always creates the same architecture, but gives\n",
        "it a different random seed at each iteration:"
      ]
    },
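    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a simplified illustration, and not AdaNet's internal code, the selection step\n",
        "amounts to keeping the candidate ensemble with the lowest `adanet_loss`. The\n",
        "values below are copied from the `Computed ensemble metrics` line in the\n",
        "iteration-0 log of the `simple_dnn` run above. In that run, `t0_1_layer_dnn` was\n",
        "the candidate that moved on to the next iteration.\n",
        "\n",
        "```python\n",
        "# adanet_loss per candidate ensemble, copied from the iteration-0 log above.\n",
        "adanet_losses = {\n",
        "    't0_linear': 0.425563,\n",
        "    't0_1_layer_dnn': 0.343846,\n",
        "}\n",
        "\n",
        "# Keep the candidate whose ensemble has the lowest adanet_loss.\n",
        "best_candidate = min(adanet_losses, key=adanet_losses.get)\n",
        "print(best_candidate)\n",
        "```"
      ]
    },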
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "-BAnb_XGwhRy"
      },
      "outputs": [],
      "source": [
        "class SimpleCNNGenerator(adanet.subnetwork.Generator):\n",
        "  \"\"\"Generates a `SimpleCNN` at each iteration.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, learning_rate, max_iteration_steps, seed=None):\n",
        "    \"\"\"Initializes a `Generator` that builds `SimpleCNNs`.\n",
        "\n",
        "    Args:\n",
        "      learning_rate: The float learning rate to use.\n",
        "      max_iteration_steps: The number of steps per iteration.\n",
        "      seed: The random seed.\n",
        "\n",
        "    Returns:\n",
        "      An instance of `Generator`.\n",
        "    \"\"\"\n",
        "    self._seed = seed\n",
        "    self._cnn_builder_fn = functools.partial(\n",
        "        SimpleCNNBuilder,\n",
        "        learning_rate=learning_rate,\n",
        "        max_iteration_steps=max_iteration_steps)\n",
        "\n",
        "  def generate_candidates(self, previous_ensemble, iteration_number,\n",
        "                          previous_ensemble_reports, all_reports):\n",
        "    \"\"\"See `adanet.subnetwork.Generator`.\"\"\"\n",
        "    seed = self._seed\n",
        "    # Change the seed according to the iteration so that each subnetwork\n",
        "    # learns something different.\n",
        "    if seed is not None:\n",
        "      seed += iteration_number\n",
        "    return [self._cnn_builder_fn(seed=seed)]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8sdvharsLJ1T"
      },
      "source": [
        "With these defined, we pass them into a new `adanet.Estimator`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 2658,
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 11332,
          "status": "error",
          "timestamp": 1545240713586,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 300
        },
        "id": "-Fhi1SjkzVBt",
        "outputId": "e61c742c-41a6-4b93-91fe-bcd074c878b1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp26gkPN\n",
            "INFO:tensorflow:Using config: {'_save_checkpoints_secs': None, '_num_ps_replicas': 0, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_global_id_in_cluster': 0, '_is_chief': True, '_cluster_spec': \u003ctensorflow.python.training.server_lib.ClusterSpec object at 0x7fb7cbc989d0\u003e, '_model_dir': '/tmp/tmp26gkPN', '_protocol': None, '_save_checkpoints_steps': 50000, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_session_config': allow_soft_placement: true\n",
            "graph_options {\n",
            "  rewrite_options {\n",
            "    meta_optimizer_iterations: ONE\n",
            "  }\n",
            "}\n",
            ", '_tf_random_seed': 42, '_save_summary_steps': 50000, '_device_fn': None, '_experimental_distribute': None, '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_evaluation_master': '', '_eval_distribute': None, '_train_distribute': None, '_master': ''}\n",
            "INFO:tensorflow:Not using Distribute Coordinator.\n",
            "INFO:tensorflow:Running training and evaluation locally (non-distributed).\n",
            "INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps 50000 or save_checkpoints_secs None.\n",
            "INFO:tensorflow:Beginning training AdaNet iteration 0\n",
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Building iteration 0\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Create CheckpointSaverHook.\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n",
            "INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp26gkPN/model.ckpt.\n",
            "INFO:tensorflow:loss = 2.694181, step = 1\n",
            "INFO:tensorflow:global_step/sec: 26.4719\n",
            "INFO:tensorflow:loss = 0.46676728, step = 101 (3.779 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.0665\n",
            "INFO:tensorflow:loss = 0.24282512, step = 201 (3.563 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.2582\n",
            "INFO:tensorflow:loss = 0.37682933, step = 301 (3.539 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.239\n",
            "INFO:tensorflow:loss = 0.32417423, step = 401 (3.541 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.3116\n",
            "INFO:tensorflow:loss = 0.36117983, step = 501 (3.532 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.2307\n",
            "INFO:tensorflow:loss = 0.15681608, step = 601 (3.542 sec)\n",
            "INFO:tensorflow:global_step/sec: 27.5131\n",
            "INFO:tensorflow:loss = 0.34024912, step = 701 (3.635 sec)\n",
            "INFO:tensorflow:global_step/sec: 27.7937\n",
            "INFO:tensorflow:loss = 0.3881727, step = 801 (3.598 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.2994\n",
            "INFO:tensorflow:loss = 0.20572703, step = 901 (3.534 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.211\n",
            "INFO:tensorflow:loss = 0.16611394, step = 1001 (3.545 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.2366\n",
            "INFO:tensorflow:loss = 0.1561893, step = 1101 (3.541 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.3356\n",
            "INFO:tensorflow:loss = 0.27579066, step = 1201 (3.529 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.263\n",
            "INFO:tensorflow:loss = 0.2819903, step = 1301 (3.538 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.2554\n",
            "INFO:tensorflow:loss = 0.23625316, step = 1401 (3.539 sec)\n",
            "INFO:tensorflow:global_step/sec: 27.9935\n",
            "INFO:tensorflow:loss = 0.2637699, step = 1501 (3.572 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.1246\n",
            "INFO:tensorflow:loss = 0.21633779, step = 1601 (3.555 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.1458\n",
            "INFO:tensorflow:loss = 0.09392404, step = 1701 (3.553 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.1103\n",
            "INFO:tensorflow:loss = 0.11175862, step = 1801 (3.557 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.2291\n",
            "INFO:tensorflow:loss = 0.16241878, step = 1901 (3.543 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.2085\n",
            "INFO:tensorflow:loss = 0.2617167, step = 2001 (3.545 sec)\n",
            "INFO:tensorflow:global_step/sec: 28.2504\n",
            "INFO:tensorflow:loss = 0.17909388, step = 2101 (3.540 sec)\n",
            "INFO:tensorflow:global_step/sec: 26.4995\n",
            "INFO:tensorflow:loss = 0.14431182, step = 2201 (3.774 sec)\n",
            "INFO:tensorflow:global_step/sec: 26.1445\n",
            "INFO:tensorflow:loss = 0.41680542, step = 2301 (3.825 sec)\n",
            "INFO:tensorflow:global_step/sec: 25.8637\n",
            "INFO:tensorflow:loss = 0.12174833, step = 2401 (3.866 sec)\n",
            "INFO:tensorflow:Saving checkpoints for 2500 into /tmp/tmp26gkPN/model.ckpt.\n",
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Building iteration 0\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Starting evaluation at 2018-12-13-18:48:13\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Restoring parameters from /tmp/tmp26gkPN/model.ckpt-2500\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n",
            "INFO:tensorflow:Saving candidate 't0_simple_cnn' dict for global step 2500: accuracy/adanet/adanet_weighted_ensemble = 0.9017, accuracy/adanet/subnetwork = 0.9017, accuracy/adanet/uniform_average_ensemble = 0.9017, architecture/adanet/ensembles = \n",
            "_\n",
            "=adanet/iteration_0/ensemble_t0_simple_cnn/architecture/adanetB\u0014\u0008\u0007\u0012\u0000B\u000e| simple_cnn |J\u0008\n",
            "\u0006\n",
            "\u0004text, average_loss/adanet/adanet_weighted_ensemble = 0.27215192, average_loss/adanet/subnetwork = 0.27215192, average_loss/adanet/uniform_average_ensemble = 0.27215192, loss/adanet/adanet_weighted_ensemble = 0.2720527, loss/adanet/subnetwork = 0.2720527, loss/adanet/uniform_average_ensemble = 0.2720527\n",
            "INFO:tensorflow:Finished evaluation at 2018-12-13-18:48:18\n",
            "INFO:tensorflow:Saving dict for global step 2500: accuracy = 0.9017, accuracy/adanet/adanet_weighted_ensemble = 0.9017, accuracy/adanet/subnetwork = 0.9017, accuracy/adanet/uniform_average_ensemble = 0.9017, average_loss = 0.27215192, average_loss/adanet/adanet_weighted_ensemble = 0.27215192, average_loss/adanet/subnetwork = 0.27215192, average_loss/adanet/uniform_average_ensemble = 0.27215192, global_step = 2500, loss = 0.2720527, loss/adanet/adanet_weighted_ensemble = 0.2720527, loss/adanet/subnetwork = 0.2720527, loss/adanet/uniform_average_ensemble = 0.2720527\n",
            "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2500: /tmp/tmp26gkPN/model.ckpt-2500\n",
            "INFO:tensorflow:Loss for final step: 0.24984711.\n",
            "INFO:tensorflow:Finished training Adanet iteration 0\n",
            "INFO:tensorflow:Beginning bookkeeping phase for iteration 0\n",
            "INFO:tensorflow:Building iteration 0\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:As the only candidate, 't0_simple_cnn' is moving onto the next iteration.\n",
            "INFO:tensorflow:Importing architecture from /tmp/tmp26gkPN/architecture-0.txt: ['0:simple_cnn'].\n",
            "INFO:tensorflow:Rebuilding iteration 0\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:Warm-starting from: (u'/tmp/tmp26gkPN/model.ckpt-2500',)\n",
            "INFO:tensorflow:Warm-starting variable: global_step; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_simple_cnn/weighted_subnetwork_0/subnetwork/dense_1/kernel; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/candidate_t0_simple_cnn/adanet/iteration_0/candidate_t0_simple_cnn/adanet_loss/biased; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_simple_cnn/weighted_subnetwork_0/subnetwork/conv2d/kernel; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_simple_cnn/weighted_subnetwork_0/subnetwork/dense/kernel; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_simple_cnn/weighted_subnetwork_0/subnetwork/conv2d/bias; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_simple_cnn/bias; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/train_op/is_over/is_over_var_fn/is_over_var; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/step; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/candidate_t0_simple_cnn/adanet/iteration_0/candidate_t0_simple_cnn/adanet_loss/local_step; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/candidate_t0_simple_cnn/adanet_loss; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_simple_cnn/weighted_subnetwork_0/subnetwork/dense_1/bias; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_simple_cnn/weighted_subnetwork_0/logits/mixture_weight; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Warm-starting variable: adanet/iteration_0/ensemble_t0_simple_cnn/weighted_subnetwork_0/subnetwork/dense/bias; prev_var_name: Unchanged\n",
            "INFO:tensorflow:Building iteration 1\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:Overwriting checkpoint with new graph for iteration 1 to /tmp/tmp26gkPN/model.ckpt-2500\n",
            "WARNING:tensorflow:`tf.train.start_queue_runners()` was called when no queue runners were defined. You can safely remove the call to this deprecated function.\n",
            "INFO:tensorflow:Finished bookkeeping phase for iteration 0\n",
            "INFO:tensorflow:Beginning training AdaNet iteration 1\n",
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Importing architecture from /tmp/tmp26gkPN/architecture-0.txt: ['0:simple_cnn'].\n",
            "INFO:tensorflow:Rebuilding iteration 0\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:Building iteration 1\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Create CheckpointSaverHook.\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Restoring parameters from /tmp/tmp26gkPN/increment.ckpt-1\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n",
            "INFO:tensorflow:Saving checkpoints for 2500 into /tmp/tmp26gkPN/model.ckpt.\n",
            "INFO:tensorflow:loss = 0.151891, step = 2501\n",
            "INFO:tensorflow:global_step/sec: 20.6697\n",
            "INFO:tensorflow:loss = 0.2222729, step = 2601 (4.839 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.256\n",
            "INFO:tensorflow:loss = 0.09993973, step = 2701 (4.493 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.1772\n",
            "INFO:tensorflow:loss = 0.20959951, step = 2801 (4.509 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.4273\n",
            "INFO:tensorflow:loss = 0.29977563, step = 2901 (4.459 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.4256\n",
            "INFO:tensorflow:loss = 0.12138975, step = 3001 (4.459 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.4122\n",
            "INFO:tensorflow:loss = 0.059896674, step = 3101 (4.462 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.3829\n",
            "INFO:tensorflow:loss = 0.24690302, step = 3201 (4.467 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.2736\n",
            "INFO:tensorflow:loss = 0.30908775, step = 3301 (4.490 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.3572\n",
            "INFO:tensorflow:loss = 0.12569772, step = 3401 (4.473 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.0862\n",
            "INFO:tensorflow:loss = 0.11210311, step = 3501 (4.528 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.2895\n",
            "INFO:tensorflow:loss = 0.07889283, step = 3601 (4.486 sec)\n",
            "INFO:tensorflow:global_step/sec: 21.9683\n",
            "INFO:tensorflow:loss = 0.19373977, step = 3701 (4.552 sec)\n",
            "INFO:tensorflow:global_step/sec: 21.1165\n",
            "INFO:tensorflow:loss = 0.21523649, step = 3801 (4.736 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.3396\n",
            "INFO:tensorflow:loss = 0.19285265, step = 3901 (4.476 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.4277\n",
            "INFO:tensorflow:loss = 0.19397877, step = 4001 (4.459 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.386\n",
            "INFO:tensorflow:loss = 0.18937594, step = 4101 (4.467 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.4573\n",
            "INFO:tensorflow:loss = 0.0852247, step = 4201 (4.453 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.3419\n",
            "INFO:tensorflow:loss = 0.10364103, step = 4301 (4.476 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.2885\n",
            "INFO:tensorflow:loss = 0.14810272, step = 4401 (4.487 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.323\n",
            "INFO:tensorflow:loss = 0.20514096, step = 4501 (4.480 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.4002\n",
            "INFO:tensorflow:loss = 0.16804385, step = 4601 (4.464 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.2988\n",
            "INFO:tensorflow:loss = 0.13350865, step = 4701 (4.485 sec)\n",
            "INFO:tensorflow:global_step/sec: 21.9738\n",
            "INFO:tensorflow:loss = 0.3896428, step = 4801 (4.551 sec)\n",
            "INFO:tensorflow:global_step/sec: 22.1545\n",
            "INFO:tensorflow:loss = 0.11354022, step = 4901 (4.513 sec)\n",
            "INFO:tensorflow:Saving checkpoints for 5000 into /tmp/tmp26gkPN/model.ckpt.\n",
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Importing architecture from /tmp/tmp26gkPN/architecture-0.txt: ['0:simple_cnn'].\n",
            "INFO:tensorflow:Rebuilding iteration 0\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:Building iteration 1\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Starting evaluation at 2018-12-13-18:50:19\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Restoring parameters from /tmp/tmp26gkPN/model.ckpt-5000\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n",
            "INFO:tensorflow:Saving candidate 't0_simple_cnn' dict for global step 5000: accuracy/adanet/adanet_weighted_ensemble = 0.9017, accuracy/adanet/subnetwork = 0.9017, accuracy/adanet/uniform_average_ensemble = 0.9017, architecture/adanet/ensembles = \n",
            "_\n",
            "=adanet/iteration_0/ensemble_t0_simple_cnn/architecture/adanetB\u0014\u0008\u0007\u0012\u0000B\u000e| simple_cnn |J\u0008\n",
            "\u0006\n",
            "\u0004text, average_loss/adanet/adanet_weighted_ensemble = 0.27215192, average_loss/adanet/subnetwork = 0.27215192, average_loss/adanet/uniform_average_ensemble = 0.27215192, loss/adanet/adanet_weighted_ensemble = 0.2720527, loss/adanet/subnetwork = 0.2720527, loss/adanet/uniform_average_ensemble = 0.2720527\n",
            "INFO:tensorflow:Saving candidate 't1_simple_cnn' dict for global step 5000: accuracy/adanet/adanet_weighted_ensemble = 0.9046, accuracy/adanet/subnetwork = 0.9005, accuracy/adanet/uniform_average_ensemble = 0.9046, architecture/adanet/ensembles = \n",
            "l\n",
            "=adanet/iteration_1/ensemble_t1_simple_cnn/architecture/adanetB!\u0008\u0007\u0012\u0000B\u001b| simple_cnn | simple_cnn |J\u0008\n",
            "\u0006\n",
            "\u0004text, average_loss/adanet/adanet_weighted_ensemble = 0.26144233, average_loss/adanet/subnetwork = 0.26964897, average_loss/adanet/uniform_average_ensemble = 0.26144233, loss/adanet/adanet_weighted_ensemble = 0.26114514, loss/adanet/subnetwork = 0.26923987, loss/adanet/uniform_average_ensemble = 0.26114514\n",
            "INFO:tensorflow:Finished evaluation at 2018-12-13-18:50:26\n",
            "INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.9046, accuracy/adanet/adanet_weighted_ensemble = 0.9046, accuracy/adanet/subnetwork = 0.9005, accuracy/adanet/uniform_average_ensemble = 0.9046, average_loss = 0.26144233, average_loss/adanet/adanet_weighted_ensemble = 0.26144233, average_loss/adanet/subnetwork = 0.26964897, average_loss/adanet/uniform_average_ensemble = 0.26144233, global_step = 5000, loss = 0.26114514, loss/adanet/adanet_weighted_ensemble = 0.26114514, loss/adanet/subnetwork = 0.26923987, loss/adanet/uniform_average_ensemble = 0.26114514\n",
            "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmp/tmp26gkPN/model.ckpt-5000\n",
            "INFO:tensorflow:Loss for final step: 0.24514686.\n",
            "INFO:tensorflow:Finished training Adanet iteration 1\n",
            "Accuracy: 0.9046\n",
            "Loss: 0.26144233\n"
          ]
        }
      ],
      "source": [
        "#@title Parameters\n",
        "LEARNING_RATE = 0.05  #@param {type:\"number\"}\n",
        "TRAIN_STEPS = 5000  #@param {type:\"integer\"}\n",
        "BATCH_SIZE = 64  #@param {type:\"integer\"}\n",
        "ADANET_ITERATIONS = 2  #@param {type:\"integer\"}\n",
        "\n",
        "max_iteration_steps = TRAIN_STEPS // ADANET_ITERATIONS\n",
        "estimator = adanet.Estimator(\n",
        "    head=head,\n",
        "    subnetwork_generator=SimpleCNNGenerator(\n",
        "        learning_rate=LEARNING_RATE,\n",
        "        max_iteration_steps=max_iteration_steps,\n",
        "        seed=RANDOM_SEED),\n",
        "    max_iteration_steps=max_iteration_steps,\n",
        "    evaluator=adanet.Evaluator(\n",
        "        input_fn=input_fn(\"train\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None),\n",
        "    adanet_loss_decay=.99,\n",
        "    config=make_config(\"simple_cnn\"))\n",
        "\n",
        "results, _ = tf.estimator.train_and_evaluate(\n",
        "    estimator,\n",
        "    train_spec=tf.estimator.TrainSpec(\n",
        "        input_fn=input_fn(\"train\", training=True, batch_size=BATCH_SIZE),\n",
        "        max_steps=TRAIN_STEPS),\n",
        "    eval_spec=tf.estimator.EvalSpec(\n",
        "        input_fn=input_fn(\"test\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None,\n",
        "        start_delay_secs=1,\n",
        "        throttle_secs=1,\n",
        "    ))\n",
        "print(\"Accuracy:\", results[\"accuracy\"])\n",
        "print(\"Loss:\", results[\"average_loss\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3wGtI-4_LRw1"
      },
      "source": [
        "Our `SimpleCNNGenerator` achieves **90.46% accuracy** on the test set."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "EeQG9tuW4RF8"
      },
      "source": [
        "## Generating predictions on our trained model\n",
        "\n",
        "Now that we have a trained model, we can use it to generate predictions on new inputs. To keep things simple, we'll use our `estimator` to predict the classes of the first 10 examples in the test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1801,
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "dzBtgkgm4RF8",
        "outputId": "72536c35-8ca3-4fb7-c372-5736ba4815a9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Calling model_fn.\n",
            "INFO:tensorflow:Importing architecture from /tmp/tmp26gkPN/architecture-0.txt: ['0:simple_cnn'].\n",
            "INFO:tensorflow:Rebuilding iteration 0\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:Building iteration 1\n",
            "INFO:tensorflow:Building subnetwork 'simple_cnn'\n",
            "INFO:tensorflow:Done calling model_fn.\n",
            "INFO:tensorflow:Graph was finalized.\n",
            "INFO:tensorflow:Restoring parameters from /tmp/tmp26gkPN/model.ckpt-5000\n",
            "INFO:tensorflow:Running local_init_op.\n",
            "INFO:tensorflow:Done running local_init_op.\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFcAAABYCAYAAACAnmu5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAACRlJREFUeJztnFtIFW0Xx39by1KzTNFKpeicWAaF\nhGZaUoE3naAog24qLaiwLoqiAxEdMPAqyAqsi4oEL7ywQCvLilQiMNJO2knTskwzMzvonvdif2tG\n53W79+fnY22/+d1MezMz+3H1f9az1nrWjE3TNA0LJXj96QEMZizjKsQyrkIs4yrEMq5CLOMqZEhf\nLzx27BiPHj3CZrOxb98+oqOj+3NcgwOtD5SVlWmpqamapmladXW1tmbNmr7cZtDTJ7dQUlLC4sWL\nAZg8eTItLS18+/atX//TBwN9Mm5jYyOjR4/WPwcFBfHp06d+G9RgoV8WNM3KoHukT8YNDQ2lsbFR\n//zx40dCQkL6bVCDhT4Zd/78+RQUFABQWVlJaGgoI0aM6NeBDQb6FIrNmTOHqKgo1q5di81m49Ch\nQ/09rkGBTRsEDlP+BJvNpn/38+dPAJ49ewbA7Nmze7xGjl5evU9iTdOw2Ww9/pYzrAxNIX3O0P4m\nuqqpqakJgPPnzwPg5+fX7ejj4wPAhAkT9Gt6upd5Qouy5Xy73d7t+56wlKuQQaVcgNLSUgDy8/MB\nmDhxIgA/fvwAoK2tDYCxY8cCsG7dOgD8/f0BQ5lmRf/69QsfHx9+//4NwNChQ12Oy1KuQgaFcr29\nvfV/37lzB4AnT54A6EoTH7lixQrAUR8BOHDgAOCI3QFmzpwJQEREBADPnz8H4P79+2zbto2nT58C\nMG3aNACGDx/udFyWchXi0XGuOeasrKwkLS0NcKTkAMOGDQO6qxtg4cKFAEyfPr3beXLPuro6wIgu\n4uPjiY+PJzU1FYBdu3YBMGPGDKfjs5SrEI9SrrOhinKXLl1KZWVlj9fI6i4KFSRKEGWL7xVFynV5\neXnk5+fr8fHbt29djtdSrkI8Klpwlc+HhIToq3dAQAAA379/BxxxKsDXr18B8PX1BaC1tRUwlHv1\n6lUACgsLAejs7ASgvr4ecMwOd7GUqxCPUq4r2tradKXJceTIkQB6MV+OEq+KYsU3y3Wi8CFDHCaS\nGsKrV6/cHo+lXIV4lHKd1WDFn1ZVVenVL/G9UlOQz7JjIttUYWFhgKHU9vZ2AH0D9vPnz4AjzgXD\nh9fU1AAwfvx4p+O1lKsQj1KuuZYq3Lp1C3CoSZQo1S/xqS0tLcC/lSxKlPhXZoFcL5mebGXNnTsX\nMHxzb1jKVYhHZWjOqv9v3rwBYN68eXr8aj7XHN+OGzcOMPba5Chxr7lVwN/fn5KSEoqLiwFITEx0\nOd4/4hbMC5P5KFPZnDQ421KJiYkBHImDLFgy/eUeYsyOjg7AmPbmdFgKNTIGOV+K8JKcuIPlFhQy\noMo1T1V3tqe7UlVVBcCVK1cAKCoqAoziS1hYmK5YKZJLEiDJhChRFjJpIJQxmZtbJDST7y9fvgw4\nejdcYSlXIQOqXGc+U9Qh4ZKU896/fw/ApUuXAHjw4AFgbJObU9z6+nqmTJkCGGoWJdfW1gKGTxWf\nm5ycDBgKzsvLAwyfK8mE+OabN2+6//e6fabFf82AhmJS9Ni7dy8A7969A6ChoQEwCtPiL8eMGQMY\nahNlm8MtWcGjo6PJysoC0JuzpUnkw4cPgOG3BSmKf/nyBYDAwEDAmA0SmrW0tFBXV6fPCFF+b1jK\nVciA+Fy73Y6XlxebN28G4OXLl44f/89KLooVtQjii+U8c4wp3eyy/X306FHdHx85cgQwCivy/erV\nqwHH4wYAL168AIwNSZkV4qvF
r8sYpZnEHSzlKsQtn5uRkcHDhw/p6OggLS2NWbNmsXv3bjo7OwkJ\nCeHkyZO6X+yJ69evs2TJEr2oIu2czc3N3Y6iFkGKKOI3RW1Tp04FjLKfqKu2tlZv9pB0VlJjSX8l\n07p9+zZg+G0p5Mi9zGOx2+1UV1frG5SPHz8G/j3buuLSLZSWllJVVUVOTg7Nzc2sXLmS2NhYUlJS\nSE5OJjMzk9zcXFJSUlzd6v8Ol8aNiYnRH+AbOXIk7e3tlJWVcfjwYQAWLVpEdnZ2r8aVIog0YEih\nWrIe8WNmBct5EjVERkYCRtQgPlhWcB8fH+Li4gBji7yiogIw/LPEq8HBwd0+i18XBYvyzdtA5qaR\n/0m53t7e+mKQm5tLQkIC9+7d091AcHCwy8ekxKhSd/VkxBW5g9vRwo0bN8jNzSU7O7vb9rI7YXJT\nUxPh4eEkJSUBRoVKsiJRQWhoKACjRo0CjHhXjubyoShdVGa32/V6hWzPyOwQlcsskdqCrAMSDZgj\nGIlYampqKC4uJiEhAYAtW7YA9Dpj3YoW7t69S1ZWFufOnSMgIAA/Pz996jY0NOhGseiOS+W2traS\nkZHBhQsX9OwlLi6OgoICli9fTmFhIQsWLOj1HuHh4QCsX78egMzMTMBY9aOiogDD34mizQ3LoiKp\nbIm7EpXZbDbdB06aNAkwfKYoUiIQWQfEf8tvSy1BjuL+5H5yvfxNveHSuNeuXaO5uZn09HT9uxMn\nTrB//35ycnIICwvTe14tuvNHtnnKy8sBR0YFRiwq2ZTMEFGdxJ6iGlGufN+1lVTOET8s6pej+c+V\nzxK/CnK9+PnXr19TXl7O1q1bATh9+rTLv9PK0BQyIMqVB+ScIQ/i7dixAzDquZKZSRYlSpXowRyD\nRkRE6L8jbfVyjkQN5i1xOV/8tsTM8pvLli0DHOuD+F13sZSrkL96a12SE6m1SqwqjRoSs0okEBQU\nNNBD7BVLuQr5q5Xr6VjKVYhlXIVYxlWIZVyFWMZViGVchVjGVYhlXIVYxlWIZVyFWMZVyID0innq\ni4zNnUZFRUVUVlbqOyUbN27UX4rRI/3+Rl4Tnvoi45KSEm3Tpk2apmlaU1OTlpiYqO3Zs0crKipy\n+x7KlevsRcZ/+4s1e+o0cufBvq4o97me+iLjnjqNvL29uXjxIhs2bGDnzp36NpQzBvw5NM3Dysdd\nO40qKioIDAwkMjKSs2fPcurUKQ4ePOj0WuXK9eQXGZs7jWJjY/VmwKSkJL1x2hnKjeupLzKWTqMz\nZ87o0cH27dv1p4LKysr0jiFnKHcLnvoi4546jVatWkV6ejq+vr74+flx/PjxXu9h7aEpxMrQFGIZ\nVyGWcRViGVchlnEVYhlXIZZxFfIPqzqcpS3v/SMAAAAASUVORK5CYII=\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fb7d6558f50\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Predicted class: 9, confidence: 98.987%\n",
            "Actual class: 9 \n",
            "\n",
            "\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFcAAABYCAYAAACAnmu5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAACQ5JREFUeJztnE1IFH8Yxz+rrqWlWaaUFUEvBzE7\nSAYmWdoLRJfqoOLBi9ELVFhEQkhdosQgKCMqQYmiMIQORWCUUBnmIcNwL1lEL/Sippmambs7/4M8\nM7vjri/r/mr1P5/LsuPMb2Z+fvc7z/PMM2PTNE3DQglh//oApjPW5CrEmlyFWJOrEGtyFWJNrkIi\nAt3w9OnTtLS0YLPZOH78OKtXrw7mcU0PtABoamrS9uzZo2mapr1580bLzc0NZJhpT0DKbWxsZPPm\nzQAsX76cnp4e+vr6mD179mT/0QDYbLZR12tvbwfgwIEDABQWFgKQmppKRMTwKYWHhwPw9u1bAG7e\nvAnAypUrAdi3bx8AM2fOnNQxj0ZAk9vZ2UlKSor+fd68eXR0dEx6cseaVCExMRGA27dvj7nuggUL\nAMjMzAz8wAIkYM/1RAswgx5LqZ8+fQKgpqYGgKqqKgDsdjsAP378AGBgYAAwFO2L1NRUwFB0S0sL\nAElJSQDk5eUBcPToUQAWLlw44fMxE1C0kJiYSGdnp/69vb2dhISESR/MdMOmBSC75uZmKioqqK6u\nxuFwcOrUKW7dujXpgxkcHARg//79ADx79gwAl8sFwNy5cwGIiYkBYMaMGYChxj9//tDR0QFAXFwc\nAGFhYV6fZvr6+rw+5Vewfft2AM6fPx/w+QRkC2lpaaSkpJCfn4/NZuPkyZMBH8B0JiDlqiI3NxcA\nh8MBwKJFiwBvZQJERkYCI73e5XLpaha1e/5tNMz+L1FGY2MjYFxEJ4KVoSkkKNHCZPn48SNgKHbJ\nkiWAoVSn0wlAb28vAO/evQOgv78fMFQZGRnJ0NAQgB7viiJlLIk05syZA8CKFSv0bT2R7aurqwEo\nKSmZ8HlZylVISCj3yZMngHGl/vXrF2B4rahRkpTr168DRiwqftjR0cH8+fMBcLvdgKFAUb/so7m5\nGUC/GC9evNhrX7LvyspKwFJuyBES0cLWrVsBeP/+PTCcTgNERUUB0N3dDaAnKg8ePACgtbUVgC9f\nvgCwZcsW7t27B0BycjJgxM7mWFmUKml8fHw8YMS74sGi8M+fP+u/ivFiKVchIeG5DQ0NwHCFDQx/\nFO8VPFNuQK8hixeXlpbq3lhUVATAtWvXvMbMysoC4PHjx4ChUKlTmKtqEk28evWKnJycCZ2XpVyF\n/FPlildKWVBiUPFHc3wrGZt5e9muq6uLQ4cOea1z6dIlwIh329ravP4uypTlolj5nDVrFgB1dXWW\nckOJf6rcsrIywPBWqWSJEkWx4qnijx8+fACgp6cHMPzS6XTy/ft3wPBOqTVI/Prz50/AqLh9+/bN\nax9SVROlS1ws14WJYClXIf9UueJhX79+BeDly5eAEddKzClRgahx2bJlgFGj9fRJ8WtRqihQ1pHM\nTWoLaWlpgLf6wfB9ueeWn58/4fOzlKuQkMjQhN+/fwOGD1ZUVABw9+5dwMi6xBelpiBZmKjOF3Ka\nokiJAmSstWvXAnDhwoVgnApgKVcpIZGhCdJDsHTpUgCOHDkCwJ07dwDjLoFcwSVaEMWKr4LhrWbP\nFR+XfUlNeKIx7HiwlKuQkFCuqEvU5qlAMKpkZoWa+x3cbrffu7xmzP4s+/Acy3Mf421Y8SQkJlcO\n3DwxUuKTUqOEV9HR0T6393VtNv/DZFu5CArmbiEZa7z/LF9YtqCQkFCuYL69LWmwFM0lTZY0WG46\nem5nvpBJ6CUKFOXKxVDWn4xC
/WEpVyEhpVzzRUO+my9g8mlu9LDb7bovC2YFm8f0p9xALmBmLOUq\nJKSU6w8pZEtRXcIoUZuoUlQ4GrKulCJlm/FsO1Es5SpkSihXSo2CpL8STXhGCOaIw5z+SnwrUYOo\n3+zVwcBSrkLGpdzy8nJevHiB0+lk7969pKamcuzYMVwuFwkJCZw9e3ZEI1swkexJvFb25SsdlmVS\nmJHvEhOLomNjY732Ibd/gsmYk/v8+XPa2tqoqamhu7ubnTt3kpGRQUFBAdu2bePcuXPU1tZSUFAQ\n9IOb6oxpC+np6XrremxsLAMDAzQ1NbFp0yYAsrOz9QZhZQcZFuYzg3K73SOu8i6XC5fL5eW/MKxY\nm82mL5cx7XY7drud/v5+vfzouf5kGFO54eHhuvnX1taSlZVFQ0OD/tOMj4/Xq/mTxd/JPHr0KCjj\nT4RgJBHjjhYePnxIbW0tVVVVeuMcBP6YlC/8PTqVnZ3t9V3UKn7qmXWJd8ptHLMvy3dpP5XHsaQw\nL/Y23gcOR2Nc0cLTp0+5fPkylZWVxMTEEB0d7XW/K5DnBf4PjDm5vb29lJeXc+XKFb1pY926ddTV\n1QHD7Zzr168PysGM5XNDQ0MMDQ3pviqIf3r+ipxOp1dBXMYWn5YxIiIiiIiIGLF+MBjTFu7fv093\ndzfFxcX6srKyMkpLS6mpqSEpKYkdO3YE9aCmC2NObl5env7opifyIEYw8edzErNKrGrGs71fsjlR\ntvmWkUQdsi9zW38wsTI0hUyJ2oIgajQ/BiXLJb4F//fVZLmv+DjYWMpVSEgp11+kIA/9SYOeJDDm\nRuXBwcERy8RjRe1m35a/mz3XuhMR4oSUcv0h7Z1yx1bUJ2374p8ul2tEP4IgtV9RqLTrS0b3+vVr\nr/X/WoZmERghpVx/almzZg0Aq1atAozWI7N/ut1uvU5rjhbMj0CJb0ubv/kdOJbnhjgh1fw83bCU\nqxBrchViTa5CrMlViDW5CrEmVyHW5Crkr2RoU/VFxuZOo/r6ehwOh34vsaioiI0bN/ofQME7eb2Y\nqi8ybmxs1Hbv3q1pmqZ1dXVpGzZs0EpKSrT6+vpxj6FcuapeZKya9PR0/RcmnUYTvVuh3HM7Ozv1\nNyKB8SLjUMdXp1F4eDg3btygsLCQw4cP09XVNeoYf70qpk2xUoZnp1FraytxcXEkJydz9epVLl68\nyIkTJ/xuq1y5U/lFxuZOo4yMDP3J+ZycnBEFdjPKJzczM1PvznE4HCQmJoa834LvTqODBw/qL/Vs\namrSX3ThD+W2MFVfZOyr02jXrl0UFxcTFRVFdHQ0Z86cGXUMq56rECtDU4g1uQqxJlch1uQqxJpc\nhViTqxBrchXyH8EWroM5cjHEAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fb7d7810bd0\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Predicted class: 2, confidence: 99.971%\n",
            "Actual class: 2 \n",
            "\n",
            "\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFcAAABYCAYAAACAnmu5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAABvFJREFUeJztnE1oE1sUx39xYrXxq6a0aHVlUShF\nF0IXsX4rgiu1iJQuulHqSqgbRfFjIahUEEQXtYIrN4W4FRTJQoUYwYViFn7hwirWxkZbtWo6mbfo\nO5N0bNqYl1PflPuDMs00d3I4+ffMueeeuQHHcRwMKsz62wbMZIxzFTHOVcQ4VxHjXEWMcxUJljrw\n7NmzPHnyhEAgwPHjx1mzZk057ZoZOCWQSCScjo4Ox3Ec59WrV86+fftKucyMpyTlxuNxtm/fDkB9\nfT1fvnzh69evzJ8/v6xfvJdv374BcPPmTQBaWloAivrcd+/eAfDmzRsAIpEIAJZlld1OoSTnplIp\nGhsb3dfhcJiBgQF1586bNw+A9vb2Px67bNmyccfpoOSYm4+jPIP+9esXALFYDICTJ08CEAqFAKir\nqwOgoqKCiooKAIaHhwH48eMHAM+fPwdyX8zmzZuBnII1KClbqK2tJZVKua8/fvxITU1N2YyaKZSk\n3ObmZi5fvkxrayvJZJLa2lrVkCBqXLRoEQA9PT0AHDt2DIBHjx4B8P79e1ep1dXVACxcuBCAvXv3\nArBr1y4gF781Kcm5a9eupbGxkdbWVgKBAKdPny63XTOCgKMdMMvIvXv3gJwqP336BMCZM2cA+PDh\nA9+/fwfGQhfA+vXrAejo6AByWUM4HAZQzc/NDE2RsmQL04VkB4ODgwAsXboUgO7ubmBMuf39/QCs\nWLECyMXpz58/AzA6OgroZzhglKuKr5QbDI43V9QohMNhNyUcGRkBxtJEgFmzxnQUCATGHTUxylXE\nV8rNZrNATnVSF7BtG4ChoaGCYyXGylgZo4lRriK+Uq7MqqTWMGfOHGC8ouX3QtmA/P3nz5+qtoJR\nriq+Uq7ESVGlHPOV6z0nryXTkGuYmOtzfKVcUWNlZeW41/kq9a4sePPZuXPnapvpYpSriC+VK0jc\nlNmXbdu/vUeQmoLUhqUGoYlRriK+UK6sLngrWhJPC6kVcqqWsZIby2xOznvrFuXAKFcRXyhXMgC5\n0xeqaOWf9+a7gsRpWYnQUKxglKuIL5QriOpmz579x2NF1VKXmA584VxvEUZuUhPdyCb7W/55CQfe\nm2M5MWFBEV8oV9Il7xKNqE5ueNlstmBBxqtQb9HcpGI+wxfKFbylRm+cLGa53PsezSV2o1xFfKHc\nyaa3+eQXy4VCU2WJtcVeuxSMchXxhXIFyQq82UK++rxxuFBxx7tQKQWdcmKUq0hRyu3q6uLx48eM\njo5y8OBBVq9ezZEjR7Btm5qaGi5cuOAWoTXwxsep4upkyAxOjplMpmx2epnSuQ8fPuTly5f09vaS\nTqfZs2cPkUiEtrY2du7cycWLF4lGo7S1takZ6VemDAtNTU1cunQJGGuBHxkZIZFIsG3bNgC2bNlC\nPB5XNTKbzZLNZnEcB8dxsG0b27YJBALjfiYjGAwSDAaxLAvLstyHU+SaGkypXMuy3L7YaDTKxo0b\nefDggRsGqqurGRgYUDFOWLx4MYD7hfqForOFu3fvEo1GuX79Ojt27HDPT0cTsXx5T58+BXJF86kq\nYPl4W0glS6ivrwdg+fLlZbT4388s5k3379+nu7uba9eusWDBAkKhkLuu1d/f7z5/YBjPlM4dHh6m\nq6uLq1evUlVVBcC6deu4ffs2AHfu3GHDhg2qRmYyGTKZjBtbJX4WE2sFiduCXKOvr4++vj4Vu6cM\nC7du3SKdTtPZ2emeO3/+PCdO
nKC3t5e6ujp2796tYpzf8cWjUhJr5alNaWfyztQkg8hH1OrNb+Uo\nN8tVq1aV3W4zQ1PEF7UFmffL4qKkht6WUsuyfluJkMVMb+uTNFIvWbJEzW6jXEV8oVx5YO/169dA\n7hEpUbDgOI67FlYoi5DHU6XCZpTrU3yRLQiyCpxMJgFIp9NAbuMK27Z/W8X1NpLIxhey00kpDSbF\nYpSriC+UO1W9Vh5FHRoacrMAea/ksbLZhlfZpuPGp/hCuX7FKFcR41xFjHMVMc5VxDhXEeNcRYxz\nFZmWqphfNzL2dhrFYjGSyaS7lrh//353w80JUdmVNw+/bmQcj8edAwcOOI7jOIODg86mTZuco0eP\nOrFYrOhrqCv3b21k/F9pampy/8Ok0+hPN8BQj7mpVMotnkBuI+P/OxN1GlmWxY0bN2hvb+fw4cPu\nznyFmPaVCMdnpYz8TqNnz55RVVVFQ0MDPT09XLlyhVOnThUcq65cP29k7O00ikQiNDQ0ALB161Ze\nvHgx6Xh15zY3N7vdOdOxkXG5mKjT6NChQ7x9+xaARCLBypUrJ72Geljw60bGE3UatbS00NnZSWVl\nJaFQiHPnzk16DVPPVcTM0BQxzlXEOFcR41xFjHMVMc5VxDhXkX8AA1aJx3F85IYAAAAASUVORK5C\nYII=\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fb7d6d23310\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Predicted class: 1, confidence: 100.0%\n",
            "Actual class: 1 \n",
            "\n",
            "\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFcAAABYCAYAAACAnmu5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAACJtJREFUeJztnEtIFf0bxz9qNzXLSi26GpZlNypI\nUulKFF3otkhr0aJCV4FBJJQUbboouKlFF7AWEVluC7RoE2UWRlJGpESWRXbRzGulzbvw/53Rqd56\ndX79jzKfzeGMZ8aZ53zPM8/tN0GWZVn4GCH4/30C/RnfuAbxjWsQ37gG8Y1rEN+4BhnQ0x2PHDlC\neXk5QUFB7N+/nzlz5nh5Xv0DqweUlpZa6enplmVZVlVVlbVly5aeHKbf0yPllpSUsGLFCgDi4uJo\naGigqamJoUOHevrFi0+fPgFQVlYGwMWLFwGIiooCYMeOHQDExsby9u1bAK5cuQLA1atXARg2bBgA\nGRkZAKxZs8bIuXajJ99Idna2df36dfv91q1brefPn3v2jfcXeuxzXV+QF4exKS8vB+Do0aMAhIaG\nAvD161cAhgwZAkBjYyMAt2/fBuDNmzdMnToVgIEDBwIwZcoUACIjIwH48uULAFVVVQCsX78egIMH\nD3p6DdDDaCEmJoYPHz7Y79+9e0d0dLRnJ9Vf6JFyU1JSOHHiBGlpaVRUVBATE+OJv62rqwPg3Llz\nAMybNw+A5uZmAL5//w5AcHCnJgYM6Dz9iIgI+xj6m16lWClZ+yxatAiAly9fAnD8+HEAsrKyen0d\nokfGnT9/PjNnziQtLY2goCAOHTrk2Qn1J4Isrx1mL5Dfk4+V+uRbW1tbAUd94eHhgKPKyMhI+zPa\nVz5WhISEdDtGWFgYAA8fPgQgPT0dgLlz5/b6evwMzSABpdzq6moAcnNzgc4bJ8CoUaMA+Pz5M+Ao\nVQwePBiAjx8/2tvka6VUN9qnvr6+23Yvfa6vXIN4Eud6xaRJkwBYunQpAJcvXwYgOTkZgG/fvgFO\n9KAMTSocPXq07UObmpq67TNixAgAO4MTOtaBAwc8vhpfuUYJKJ/rJiEhAYB169YBncoEJxJQvUD+\nFRylSs16r+xOila9YtWqVUBneOk1vnINElA+152B3b9/H4DDhw93+5yyQUUN8pvh4eG0t7cD2K+K\nmXVs0dHRAZhRrPCVa5CAUq4UK6TQadOmAVBZWQk4ahw+fDjgxLJhYWG2IvW3N2/eAI7KpeC4uDgz\nF9EFX7kGCSjl/gqpTTUGKbytrQ1wooW2tjY7zlW0IFRLEOPGjTN3wv/DV65BAlK5Cr2DgoIAmDx5\nMuD00KRkdSSk5JaWFtv/qjqmypkUXVtbCzgxs3BHKl7gK9cgAalcNxMmTAAcdalGq1ZTfHw80BkR\nSJlqOylKkMp1jF9Vy7zEV65B+oRyded3+0O9V2zb0tJib1MV7N27dwA0NDR021cZnEl85RokIJWr\nKEHIP+oOLyW72/lRUVG2b1UMPGbMGMBRsCppf4OANK47FGtpaQGc1rvCq/fv33fbb/To0XYRRyXF\nQYMGdfuMbmivX78GYPr06YC3IZjw3YJBAlK5bregIoza3Sq6qPCtBKGmpsZ2GRpr0t+k+vHjxwPO\nMIhJfOUaJCCV6+bevXuAU3pU0UUJgtQZHx9v+1wpWAWbrm13cHxu10I7/Ojve4OvXIMEVIPSrRoN\ngeTl5QGOct1FGDUym5ubefr0KeAMkihqcCNly5+npqZ6eCWd+Mo1SED5XLefu3PnDuCMkireVWqr\nto8ShcrKSjvhiI2NBZwmp/y0kgkNlLx48QJwfLIU7wW+cg3yR8rNycmhrKyM9vZ2MjIymD17Nvv2\n7aOjo4Po6Ghyc3N/yIS8QFFC
YmIi4BRoVHJUiis0AAI/DkHLT0upag3pVf7bS+X+1rh3796lsrKS\ngoIC6uvr2bRpE0lJSWzbto3Vq1eTl5dHYWEh27Zt8+yk+gu/Ne6CBQvsBXzDhg2jtbWV0tJSe1Bj\n2bJl5Ofne2pc3eGVTcnX6s4uhSrO7apgxbW/Gn5WPKuW+8SJE4Ef42BP+C9Lfy5dumTt3bvXWrhw\nob2turraSk1N9WJlUb/jj6OFGzduUFhYSH5+PitXruz65Xj+hSsKULSgMqIWlkjJ8p9acFhdXc2j\nR48A7HMsLi4GYMaMGYCTkT158qTbdvleHUuDJ73hj6KFW7ducerUKc6ePUtERARhYWH2T7G2ttae\nAPfpzm+V29jYSE5ODufPn7e/3eTkZIqKitiwYQPFxcX2siOvUFSgVylWQyEaB1WWpfg4ODjY/ox8\nr1QvlaekpACdv0Rw/Lr8uFryXij3t8a9du0a9fX1ZGZm2tuOHTtGdnY2BQUFjB07lo0bN/b6RPoj\nvzVuamrqT/NuLcQzgRaB6E6vbKqmpgb4MVPrGj1IudomBT548ABwlqNqX/0PRQv6tXiBn6EZJKBq\nC0L5v/pdUpkULVUqRu26qERxrHvYWUjJI0eOBJx4WOOq6lh4sZbZV65BAlK5quNKTe6BDoWBigQ0\n4FFbW2u33+WXVTPQq/y2FOsewFMc7AW+cg0SkMpV9qQOrrs+oDu6IgP50cTERDu7kx9eu3Yt4ChU\nr/Lf+nVofmHWrFmeXYevXIMEVA9NuP2g5eqt6Y6ujFE+uutiv0DAV65BAlK5/QVfuQbxjWsQ37gG\n8Y1rEN+4BvGNaxDfuAb5K7WFvvogY/ek0c2bN6moqLAzwZ07d9oPO/oppnv3ffVBxiUlJdauXbss\ny7Ksuro6a8mSJVZWVpZ18+bNPz6GceX+7QcZe8XPJo3+a3/NuM/98OGD3aaBzvaKe4lTIBISEmIP\n7xUWFrJ48WJCQkK4cOEC27dvZ8+ePXYB6Vf89Xqu1cdKGV0njR4/fkxkZCQJCQmcOXOGkydP/utD\njY0rty8/yNg9aZSUlGQvEVi+fDnPnj371/2NGzclJYWioiIATx9kbBpNGp0+fdqODnbv3s2rV68A\nKC0ttTslv8K4W+irDzL+2aTR5s2byczMJDQ0lLCwMPuZ6r/Cr+caxM/QDOIb1yC+cQ3iG9cgvnEN\n4hvXIL5xDfIPUWKcGjEJA8gAAAAASUVORK5CYII=\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fb7ccfa0590\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Predicted class: 1, confidence: 100.0%\n",
            "Actual class: 1 \n",
            "\n",
            "\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFcAAABYCAYAAACAnmu5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAB9hJREFUeJztnFtIVF0Ux39HTcepKdO0MoqwDOxG\nRT6odI+ipy7gBR98KeopMIiCiHzqgkF0e8iC6qEXQaqXBCWMMDCJhEgjMirsqqmTWo3mzJzvwW+d\no2PzGfPNnu87sn8v09njGZer/6yz9tprb8M0TRONEuL+awMmM9q5CtHOVYh2rkK0cxWinauQhEhv\nPHXqFM+ePcMwDI4dO8bKlSujadfkwIyA5uZmc//+/aZpmubr16/NoqKiSD5m0hNRWGhqamLr1q0A\nLFq0iL6+Pr5//x7V//RoMzAwwMDAQEx/Z0TO7e7uZubMmdZ1amoqX79+jZpRKvB4PHg8npj+zohj\n7mhMxTPowsJCADo7OwGYO3cuAFVVVQCkpKSMu0dUWlBQAGB9s5YuXQrA7du3AUhMTFRldmTKzcjI\noLu727ru6uoiPT09akZNFiJSbkFBAZcuXaKkpIS2tjYyMjKYNm1atG2z8Pv9ALx//x6A9vZ2ABYs\nWADYyt27dy8XL14EIBAIAJCcnAzArFmzAOjp6QHUKlaIyLlr1qxh2bJllJSUYBgGFRUV0bZrUhBx\nzD18+HA07fhHRHVxcSNRbM6cOYAdgz98+ABARUUFLS0tADQ2NgJY4Wp4eBiAefPmxchqPUNTiiOc\nm52dTXZ2Nn6/H7/fT0JCAgkJCbhcLlwuF6ZpWhlLVlYWWVlZeL1evF4viYmJJCYmMjQ0xNDQEP39\n/fT398fEbkc416lEJc9VzerVqwE7a5DYO2PGDMB+8n/69Mm6x+12AxAMBsfcm5qaGgOLR9DKVYgj\nlCszsilTpgB2Djs0NATA/PnzgZG8V9S8cOFCwFasKFjejwVauQpxhHLT0tIASEpKAmwVxsfHA1iz\nw7Vr11pqFnVLnvvz509AfR1kNFq5CnGEcqVUKDWFVatWATB16lTAVuOvX7+se0Tl8p7E61jUFASt\nXIU4QrmhFTeJuaHxFewcWGoJki2IqnWeO0lwhHIFl8s15towjDHXcXFxlnIlk0hIGPkTJfbqPHeS\n4CjliirDjQeDQUvNMiaKlfFYLlJq5SrEUcqVLCHcuGQGYGcQoliJvZ8/f1Zp4hi0chXiKOWG1gXk\nWl4DgYA1AxM1S34r42/evImJreAw54aGBfnqjx6Xf8sDTUKFhIWXL18qt1PQYUEhjlCudPfIV1xU\nOXraK9ehoUJeJSx8/PhRvcF/o5WrEEco9927d4C9dCPlRCncCIFAIGxclgVLWcSUB1tWVpYSm0Er\nVymOUG5tbS1gTwhEnaFTW9M0x02RJVuQn8nJyQGwGvbOnz+vzG6tXIU4QrkPHz4E7FgrhfBQ5Y6e\n/goSc3/8+AHYsbe+vl6hxSNo5Srkj5RbWVnJ06dP8fv9HDhwgBUrVnDkyBECgQDp6emcPXtW6cLf\nixcvAHuZXGJuaLE8EAiMGxNE1bLELm2noTO4aDLhJz5+/Jj29naqq6vxer3s3r2bvLw8SktL2bFj\nB+fOnaOmpobS0tKoG+d0JgwLubm5XLhwAYDp06fj8/lobm5my5YtAGzatImmpiYlxg0ODjI4OEhn\nZyednZ0kJSWRlJTE8PAww8PDGIaBYRhWC6lhGFZ7qYwFg0GCwSA+nw+fz0dhYSGFhYW43W7cbjcd\nHR10dHQosX9C5cbHx1sPgZqaGtavX8+jR4+sMJCWlqZsm5SsmYWr4/4bTp48GfXPDOWPA839+/ep\nqanh+vXrbNu2zRpX2R4kFSzJ
TfPz8wHwer3A+Cyhq6vLatqTBcpQFi9eDMCDBw8AqK6uBmD79u1R\nt/+PsoXGxkauXLnCtWvX8Hg8uN1uBgcHgZF9CRkZGVE3bDIwoXIHBgaorKzk5s2b1pak/Px86urq\n2LlzJ/X19axbt06JcXfu3AFg9uzZgJ3niirlta+vDxj5Fkm9wefzAVg7PeXet2/fAnZ28OTJE0CN\ncid0bm1tLV6vl/LycmvszJkzHD9+nOrqajIzM9m1a1fUDZsMTOjc4uJiiouLx43fuHFDiUGjefXq\nFWArU1QnDzjZMiU12paWFqvFXx7Csi01dAurbF99/vy5Mvv1DE0h/+vaQlFREQB3794FxtcSQrf4\ni1rBjqkSawWJ0/Kz0o6qAq1chRhmLPvYI0SUKk10kj3IRKa1tRUYidHZ2dkALFmyBIDe3t4x90r8\nlg3W3759G/N+NNHKVcj/OuYKsu6VmZk5Zjy0pbSnp8dSrmQQUgWTWZwoV7IElS2lWrkKcYRypV4g\njwfZeCKv9+7dA2D58uXWPXIGxJcvXwAoKysD1GYHoWjlKsQR2YJT0cpViHauQrRzFaKdqxDtXIVo\n5ypEO1chMZmhOfUg49BOo4aGBtra2sYcb7hx48bwH6D6gF6nHmTc1NRk7tu3zzRN0+zt7TU3bNhg\nHj161GxoaPjjz1Cu3HAHGas8WDMa5ObmWt8w6TQK3YMxEcpjrhMPMobfdxrFx8dz69YtysrKOHTo\nkFWID0fMq2Kmw0oZozuNWltbSUlJIScnh6tXr3L58mVOnDgR9l7lynXyQcahnUZ5eXlWa9XmzZut\npf9wKHduQUEBdXV1ADE5yDhaSKdRVVWVlR0cPHjQOiC5ubnZWvUIh/Kw4NSDjH/XabRnzx7Ky8tJ\nTk7G7XZz+vTpf/wMXc9ViJ6hKUQ7VyHauQrRzlWIdq5CtHMVop2rkL8Ammm37NJ/8tgAAAAASUVO\nRK5CYII=\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fb7d3f78810\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Predicted class: 6, confidence: 89.767%\n",
            "Actual class: 6 \n",
            "\n",
            "\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFcAAABYCAYAAACAnmu5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAACZhJREFUeJztnGlIFe8Xxz/Xe1u0LC21hSLa9z1p\ntz3oRZQFWVERJUVQoBQWFfUmKhKCFmijDWyxpF5EkRUS+EIvtFISpRTRQqZppmXmvc7/hf8z431u\nlnd06uZvvm/GmTvLM8fPc+ac8zwzDk3TNGxZopC/3YDmLNu4Fso2roWyjWuhbONaKNu4Fspl9sDd\nu3fz+PFjHA4HW7duZdiwYU3ZruYhzYTcbre2Zs0aTdM0raCgQFu0aJGZ0zR7mSI3JyeHmTNnAtC7\nd2/KysqoqKigbdu2Zv/BADgcjl/u9/XrVwAePnwIwMGDBwGIiIgAYMiQIbRu3RqAkpISAO7evQvA\n5MmTAUhJSQGgRYsWjWpLQ2TKuMXFxQwePFhf79ChA0VFRaaN29AbadOmDQCTJk3yWf5KW7ZssaQt\nDZFpn1tXmskMuj5Kvn37BsDJkycBuH79us/2du3aAVBZWQnA/fv3ASgrK9PPIWT27NkTqO1hABUV\nFQB07twZgOnTpwOwcuVKAJ38ppCpaCEmJobi4mJ9/ePHj0RHRzdZo5qLHJoJ7B48eMChQ4c4ffo0\neXl57Nq1iwsXLgR8cZVc8ZPiz4U6oUlodDqdPts7duwIQHl5ud9vVVVVQC0AAB6PB4AfP34A8P37\ndwDdpSUlJQEwduzYgO9HlSm3MGrUKAYPHszixYtxOBzs3Lmz0Q1pjjJFrlVat24dAEVFRUDtgxKg\nuroaMAhXCW7VqpW+VIkVPyykqrdbU1Pj87tcKy0tDYCWLVuavh87Q7NQTRItNFbyBH///j0A7du3\nBwyaXK7aZkq0IPGuUBcSUsuI0+nUaRZfKuTKOYREWRdfK8RLW3JzcwGIi4szfV82uRYqKMgVEl++\nfAlAv379ACOOFfqEOiFYiBU/6vV69W3iO+U3IVpoVzO0bt26+Vzz2rVrgE1u0CooyM3PzweMGFQI\n9nq9Pkvxj7169QKgb9++APTp0weozdxCQ0MBI1UWQiV6uHfvHoAel4eHhwNQWFgIQGlpqc85GyOb\nXAsVVHFueXk5AFevXgWM6tfWrVsB6k2xxb96PB79byFV1iUakNhZJLWFgoICwOgdUnvIysoyfT82\nuRYqKHxuamoqYDz9Z8+eDcDo0aMBIzoYMWIEYDzxo6KiAIOyyMhI3cdKNicd8/PnzwC43W4Ahg8f\nDsClS5cAI8sTny1xcGNkk2uhgsLnPnv2DICbN28CRgXr8uXLACQnJwMwY8YMAL58+QLA8+fPfdYd\nDodf1UuIlF4xcuRIwIgmzp0757Of+OSzZ88CtaMuYBAdiGxyLVRQkCtPbKGne/fugJG5vXnzBoDs\n7Gyf48S/SiTgcrn8KmdCskQPUjuQGrD49y5dugCwcOFCn2vPmjXL9H3Z5FqooCBXxsoyMzOB2pEO\ngCVLlgAQGxsLwLt37wAjM1P969evX/VtIiFY4lfpHRI9SHy7bds2AOLj4wG4cuUKYPQaqT0EIptc\nCxUUce6TJ08A44kstYNp06YBcPv2bQAePXoEGDGsSqnT6fSLb2UpmZpUx4TExMREwBiml1FiGUvr\n1KmT6fuyybVQQUHuixcvAIMuWZdagsSkUs+VeQuSqUk2VbeeqxIsWZ744A8fPgBGb/n06RNg+FiZ\nOiDRRWRkZMD3FRTGFSOJEaXryrqERWI4tRQpBqypqfEroKvHyHZxKTIVSiSlR3lIitHNGNd2CxYq\nKMhVqRLaJHySoRd1qEadBuVw
OPzcgToML65HrilEyrXExQjZ0mvMyCbXQgUFuSI1XJIHlzyMVKmk\nezwenWp1Kef8WfhWdz8hWgiW7WZkk2uhgoJc1XcKkVKQUaczqRFBXcpkH/WcKolqciHJgtpLpOBj\nRja5FiooyP2dZGqSSqzqV8GfWDlG9pFERHyrxLMDBgwAjGlMklw0pq5lk2uhGkTuvn37uH//Ph6P\nh7Vr1zJ06FBSUlLwer1ER0eTmpraqKmWEmvKcI1QJRI/KOVC1feKHA6HX8FGzqUW0VX6ZaK1FOTD\nwsJ+2pZA9Fvj5ubmkp+fT3p6OqWlpcTHxzN+/HiWLl3KnDlz2L9/PxkZGSxdutR0I5qrfusWYmNj\nOXDgAFAbd1ZWVuJ2u/XBwmnTpumDeIHK6/Xi9Xr1zErWo6Ki9GFzqH1iV1VVoWmajw+U46qrq6mu\nrv7pmzgejwePx6MfGxISQkhIiN81+/fvT//+/fVruVwuXC6XfrwpBfLS2sWLF7VNmzZp48aN07e9\nfv1aS0hICOQ0/xk1OFq4c+cOGRkZnDp1Sh/U+/8/x9x/FcOfLVu2DDDKgDJYeP78ecAYUpdBRdW/\n122DpmRtarQgsfOrV68AOHLkCGBkanPnzgWMgv3GjRsBo3AfiBoULWRnZ3P06FFOnDhBeHg4YWFh\nenhUWFhITExMwBf+L+i35JaXl7Nv3z7OnDmj1z4nTJhAZmYm8+bN49atW/qrn2YlvlJIVgcD1SqY\nZFWyXXyi0Af+r2HJb2pPk8l/Ei1IG9RMzox+a9wbN25QWlqqjykB7N27l+3bt5Oenk7Xrl2ZP3++\n6QY0Z/3WuAkJCSQkJPhtP336dJM1QiVXKBJJtiTuR6pl6mQ5p9PpR5y6FHcmS6nXyjllP7X+a0Z2\nhmahgqK2oPo1dVxLRiKENsmyZBBRCHa5XH7n0pRMTF7Dkkkhb9++9TmHOmnarooFqf4quWqlSpYq\nfcuXLwcM2iQOliihru+VbaqvlXhXqJfeMWbMGJ9r1e0FddtoRja5FuqvkiuUqeNXMldAtGrVqj/W\nJolcpE11vysRqGxyLdRfJVfi10GDBgHQo0cPwP/bNaoPbsrv0Khav349YEwtlemrZmSTa6GCYvJz\nc5VNroWyjWuhbONaKNu4Fso2roWyjWuhbONaqD+Sof2rHzJWZxplZWWRl5enV9RWr17N1KlT6z+B\n1WP3/+qHjHNycrTExERN0zStpKREmzJlirZ582YtKyurweewnNym/pDxn1JsbKzew2SmUaDzxiz3\nucXFxT6vGcmHjINdTqdTn4yXkZFBXFwcTqeTtLQ0VqxYQXJysv7V1Pr0x6ti2j9Wyqg70+jp06dE\nREQwcOBAjh8/zuHDh9mxY0e9x1pO7r/8IWN1ptH48eMZOHAgUPuNCHnTsz5ZbtyJEyfqr/rn5eUR\nExMT9P4WjJlGx44d06ODDRs26K+vut1u/dME9clyt/Cvfsj4ZzONFixYQFJSEqGhoYSFhbFnz55f\nnsOu51ooO0OzULZxLZRtXAtlG9dC2ca1ULZxLZRtXAv1PzbxYRggIQtTAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fb7cbc7a090\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Predicted class: 1, confidence: 99.995%\n",
            "Actual class: 1 \n",
            "\n",
            "\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFcAAABYCAYAAACAnmu5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAABxNJREFUeJztnE1IVF0Yx3/j5KQz+ZGTIxUFZQUi\ntShcaPRBRFCbyEWIizZ9rQqFKIioRZBh0UYXljQUBCUMUZuwkAkqsFmEBUpQRt+UZU5mpVMz3Xcx\nPHf05jTTvB51pvPbXO713jPn/v3f5zzn3HOPzTAMA40Ssqa6ApmMFlchWlyFaHEVosVViBZXITNS\nvfDEiRM8evQIm83G4cOHWbFixUTWKzMwUiAQCBh79uwxDMMwent7je3bt6dSTMaTUljo7Oxk48aN\nAJSWljI4OMjXr18n9J+eCaQkbn9/P7Nnzzb3i4qK+Pjx44RVKlOYkAbN0D3ocUlJXI/HQ39/v7n/\n4cMHiouLJ6xSmUJK4q5evZqbN28C0NPTg8fjYdasWRNasUwgpVRs5cqVlJeXU1NTg81m49ixYxNd\nr4zAZuiAqQzdQ1OIFlchWlyFaHEVosVViBZXIVpchWhxFaLFVUjKbyKmAmtn0mazAfDr1y9z33os\nK2t8/0hZcn48IpGIWU6ic61o5yokI8YWErl0NNevXwdg//79ALx8+VJZvbRzFZKWzk0mXt65cweA\nQCAAQGtrK4A57jxz5kwAVq1aBUBzc/O45UjM9Xq97N69+6/qqZ2rkLRwbiKnfvr0CYBr166Zjr1y\n5QoA5eXlAMyfPx+AwsJCAB4/fgzAs2fPAAgGg+OWffv2bQAuXLjAxYsX/6re2rkKmdI81+pIiW92\nu33MeVbHjoyMANDQ0ABAS0sLAC6Xi9LSUgA2bdoEQCgUAjDnVYiTe3t7gd9j7bdv3wBob28HoKam\nBoC+vj4AXr9+DcCCBQsS3p92rkKmJObKT8o2UX4q8fHy5ctAzG0lJSUALFu2DIDs7Gwz/g4ODgKQ\nl5cHQH5+PhCLreK8t2/fAjEnO51OIPqGG6KxFiAcDpOdnW062+VyJbzPadmg+Xw+IDrZD2KP4pIl\nSwBwu91AVEyAFy9emNdaQ0pOTg4Ac+bMGbP/+fNnAL5//w7A0qVLATh//vyY6yWFO378OK9evTLD\njTSYRUVFce9DhwWFTGqDJg70+/0AdHd3A7EG6s2bN0DMVfLYi6tklo+4TR5RafAKCgrMsuSBFHf/\n+PEDiDlbOhPyeHd0dACwcOFCAL58+QLA4sWLgehcDYiGB4h2KgAOHDgQ9361cxUyKc69evUq1dXV\n1NfXA7G0SAZccnNzgVjaJI2P/F0aJ3FdQUEBgDk/TdyUlZVlulmcKy6Xc4aGhoCY6+U3JXbOmBGV\nxOPxADHny3UDAwNJ37d2rkImxblbtmwB4PTp00B08jRAV1cXEGvtJY16//49EIuT4ibpZEjsfvfu\nHRBL5cLhsOlEcaq4X5DUTLYygCMOlbIkFksMl26zPGVbt25NeN/auQqZlDx3ZGSEnJwc04kOh2PM\n38VlkuBL1vD8+XMgNqAtLhI3WgfJS0pKzHgsMVPit2QH1iFHa13k6bDK4nQ6cTgcSb8eAu1cpSTl\n3MbGRh48eEA4HGbv3r0sX76cgwcPEolEKC4u5tSpU785YDzEubK1/rQ4UOLf6CxgNFb3yH44HDbd\nLA6UMqy/LVvJXOQ8a3wfHh4GojG4oqKChw8fArEcfO7cuXHvN2GDdv/+fZ4+fUpbWxvBYJBt27ZR\nWVlJbW0tmzdv5syZM/h8PmpraxMV9c+R0LmRSIRQKITT6SQSiVBVVYXL5aK9vR2Hw0FXVxder5em\npqb/XZmfP38CsVgqVZPj4i7Z
CqPPt7rc6mTJlaXVH32fo7dSzui47na7zXxXso0/kdC5drvdHCny\n+XysXbuWe/fumWHA7XZP2GdSEg6syGDLdCAZUYWkG7SOjg58Ph9Hjx4dc3waDqpNG5IS9+7du7S0\ntNDa2kpeXh5Op9NMi/r6+sy0RzOWhOIODQ3R2NjI2bNnzV5KVVWV+anUrVu3WLNmjdpapikJG7S2\ntjaamppYtGiReezkyZMcOXKEUCjEvHnzaGhoiBsv/2Wm5ZuITEH30BSixVWIFlchWlyFaHEVosVV\niBZXIVpchWhxFaLFVYgWVyFaXIVocRWixVWIFlchWlyFaHEVMimzHNN1IWPrTCO/309PT4/5LnHn\nzp2sX78+fgHqluaNkq4LGXd2dhq7du0yDMMwBgYGjHXr1hmHDh0y/H5/0mUod268hYyn+8KaFRUV\n5hOWn5/P8PCwORsnWZTH3HRdyHi8mUZ2u51Lly6xY8cO6uvrE07hn/TPU400e9ksM428Xi/d3d0U\nFhZSVlbGuXPnaG5u/m0G0miUOzedFzK2zjSqrKykrKwMgA0bNvDkyZM/Xq9c3HRdyHi8mUb79u0z\nv8cIBALm93HxUB4W0nUh4xs3bhAMBqmrqzOPVVdXU1dXR25uLk6n0/xqPh56xo1CdA9NIVpchWhx\nFaLFVYgWVyFaXIVocRXyH3sBm5Xz4vLQAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fb7cbc8a890\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Predicted class: 4, confidence: 88.968%\n",
            "Actual class: 4 \n",
            "\n",
            "\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFcAAABYCAYAAACAnmu5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAACVxJREFUeJztnFlIVd8Xxz83rcyybNDCKPohTYgF\nRYPNRUU+NVASEQ0U9VJQEEURRRAVFj00QAPYQAOCDz1FRvnQZAZNkAQpDTSXZIOV6b2e/0P/7zne\nnbc72OmanM/LdR/PsM+637v22muvc3yWZVl4uEKbeHegNeMZ10U847qIZ1wX8YzrIp5xXSQx1gN3\n7NjB/fv38fl8bNq0iSFDhvzJfrUOrBgoKyuzVqxYYVmWZVVWVlp5eXmxnKbVE5NyS0tLmTp1KgCZ\nmZl8+vSJmpoaOnXq1NwvOqjt8/l+u/+jR48A6NWrFwCdO3f+ZZ+PHz8C8PDhQwBycnKa1cdoiMm4\nVVVVZGVl2e1u3brx/v37Zhs3nDFNBgwYEHaf1NRU4O8aVcTscxtjKq655zGN/PnzZwDOnDkDwMaN\nGwFHlZGQkJAAQNu2bQE4efIkAPPmzYuqL9EQU7SQnp5OVVWV3X737h1paWkxd6K14rNikN2dO3fY\nv38/x44do7y8nO3bt3P27NmYOxFKJRMmTADg7t27ANTW1gLY7iclJQWAb9++AZCWlkZ6ejoAT548\nAaCmpibomK9fvwLOr0GimDt3LgD79u2LqG+REJNbGDZsGFlZWcyfPx+fz8fWrVtjOU2rJybl/ilC\nqWLGjBkAXL58GYDevXsD8OPHj6D9A4FAULumpoaGhgYAOnbsCEBi4k/91NXVBW0X2v727VsAli5d\nCkBBQcEvfY1Wvd4MzUXiqlyTW7duATB69GgA+vXrB2Cr8cuXLwC0adMm5KduR8cIbff7/UHHKHro\n0KEDAG/evAHgxo0bAAwdOjTm+/GU6yJxUa5UJfXYnfm/T+vZsycA9fX1AHTv3h1wlKvtil0bny+c\nXzT9vD6lYJ3r5cuXQddUtBGq703hKddF/sgMLVrMb10jdLdu3QAnR1BZWQk4yjX9ZCyYipX65YsV\nOycnJwNw7949AMaNGxd0XCR4ynWRuCjX5MKFC0FtxbPyj4pVhak+EcnwYR5rXkPX/v79OwDXr18H\nPOW2OFpEnJuZmQlAdXU14KQJlR8YNGgQ4OQWNKuS720cu8ov67Y0izNjYqHIQzM3KffDhw8AdOnS\nBYAXL15EfV+ecl0krj731atXgDOvl2Ll79q1awc4WS+N4Io1NdJLpY3j3FDxrLbr04xvlUpt3759\n0P9jwVOui8RVufKHpl/UCoNGcKnMjCKktsbtUAo1fa2uKYWqLaXqV/H06dOY789TrovEVbkPHjwA\nnNFfqpMipSqN5J8+fQIcXywa+9VQsa+ZA9an0DW1UpGRkQE4M7anT5/aWbpI8ZTrInE1rmVZWJZF\nQ0MDDQ0Ndpzarl27X9TZeH/9PxAIEAgEbMX6/X77XH6/P6gttG9tbS21tbUkJSWRlJRkt3XOxMRE\nEhMT7e179uyJ+v485bpIXH2uWURi+lozBjVXF9SOZL5vxrWamSmKUGzdtWvXJvui/0dDXKe/I0aM\nAOD+/fsA9OnTB3CWwxWSqVxJkwkZQG2FTYFAIOjvxsiIGhyVDNfSuiY0aivs03Fv3ryx3RJE9oV6\nbsFF4uoWVByXlJQEOIkZ/QRVC6allj9RYmT+3FUcIjehcysZpGv2798/6mt7ynWRuChXSlTArgVJ\ntaUqcwHS9Kdqyy/W19f/MuhpCq19dKyW0pVaVOin5JDSnxoAHz9+HPV9esp1kbgoV9GBkAI1QpvT\nXqUkpSIz8d3YD0bqE6VUhYNm
0kg+V0o2o49I8JTrInFRrpmAVrQgdchvSrFqS21mirLxJMNMIQrt\no+26pnyv2u/evQMc5ZqLo9HgKddFIvpa8vPzuX37Nn6/n5UrV5Kdnc369esJBAKkpaWxe/fuJhMt\nocjOzg5qmwUaUovUprbUJVXpOMWofr/f9pFmktxUrtB2pT3Nxc9IypZCEda4N2/epKKigsLCQqqr\nq5k9ezY5OTksWLCA3Nxc9u7dS1FREQsWLIi5E62VsMYdMWKE/QBf586d+f79O2VlZWzbtg2AyZMn\nU1BQEJVxb968GdSW35O/VBmnEtamHzWVK2W3b9/+l2SO9jW3KzKR0qX+UDF0LIQ1bkJCgt2BoqIi\nJkyYwLVr12w30L17d96/fx/VRadMmQL8uaeAWioRD4WXLl2iqKiIgoICpk+fbm+PxUCrVq0C4ODB\ngwAMHz4ccAovFCWUlZUBzpchtalQQ+qSsn0+3y/5B3MxU/kLtTUTmzVrFgDnzp0DnKJALfM8e/YM\ny7LsWaRZ/t8UEWn+6tWrHDp0iKNHj5KSkkJycrLdybdv39pP0HgEE1a5X758IT8/n+PHj9tFG2PG\njKG4uJiZM2dy8eJFxo8fH9VFTQVqVmT6t5EjRwLY/l1K10KlYtK+ffsCjiobn0sKlhvTsfLT06ZN\nA5wHB6VcHa9fh7hy5QoAubm5Ye8zrHHPnz9PdXU1a9assbft2rWLzZs3U1hYSEZGhv2T8ggmLisR\nK1asAODo0aMA/Pfff4CjvNevXwOR+3NFBHV1dSFzvqZ/DoWOU8G1fGtVVRVfv35lyZIlABw/fjxs\nv7wZmovERblSqOJbDYiKMc3M1N9k4MCBgLOmpli7oqKChoYGTpw4AcDixYvDnstTrovEJSumHIEU\nqgfp1FYcHA6z2K6pH6Hpe0OVlKqdl5cH/By0wYkuFi5cCESmWOEp10XiuvqrOFdxr1Ty/PnzoP2U\nqfpdAV6smModNWpU0LW0Orxu3bqoz+0p10Xiqlyp5vTp04BTtt+jR4+g/ZqzGhAtqrhRskpxcSzZ\nMU+5LtIiHpVqrXjKdRHPuC7iGddFPOO6iGdcF/GM6yKecV3kr0x9/tUXGZuVRiUlJZSXl9szyWXL\nljFp0qTQJ2juC3jD8a++yLi0tNRavny5ZVmW9eHDB2vixInWhg0brJKSkojP4bpy3XqRsds0VWkU\nbY2u6z63qqrKfrQJnBcZt3SaqjRKSEjg1KlTLFq0iLVr19ol/6H461kx6x9LZTSuNHrw4AGpqakM\nHjyYI0eOcODAAbZs2RLyWNeV+y+/yNisNMrJyWHw4MHAzxIrvTM9FK4bd+zYsRQXFwNQXl5Oenp6\ni/e34FQaHT582I4OVq9eba+SlJWV2c+mhcJ1t/Cvvsi4qUqjOXPmsGbNGjp06EBycjI7d+787Tm8\nfK6LeDM0F/GM6yKecV3EM66LeMZ1Ec+4LuIZ10X+B6GVD2y+ED5zAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fb7d0b69dd0\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Predicted class: 6, confidence: 95.252%\n",
            "Actual class: 6 \n",
            "\n",
            "\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFcAAABYCAYAAACAnmu5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAACB1JREFUeJztnFlIVO8bxz+jmWhpZo3tBFmIiF0E\nlmbZYjeGtBEpkl6UFEGCgVREtNBuJEQLWiQFEQleBC1khUFGJmRl6YVp0qpppqmVLer5X8h7js3P\ncc7M37ec6f3c2Mx43vP0+J3nvM9yjkXTNA2FFLz+tgGejHKuRJRzJaKcKxHlXIko50pkmKsHHjx4\nkIqKCiwWCzt27GDmzJmDaZdnoLlAWVmZtmHDBk3TNK22tlZbs2aNK8t4PC6FhdLSUpYsWQJAaGgo\nbW1tfPnyZVD/6J6AS85tbm5m9OjR+uvg4GA+fvw4aEZ5CoNyQdNUBt0vLjk3JCSE5uZm/XVTUxNW\nq3XQjPIUXHJubGwsRUVFAFRVVRESEsLIkSMH1TBn0DQNTdPo6enR/23vd+x9bktdXR11dXX6axH2\nzB4PLm7FZs2aRUREBMnJyVgsFnbv3u3KMh6PRfOAgCn+C5qm4eXl2mWkuroagMrKSqD3Gwnw5MkT\nfe0rV67Q1dUFwLBhjnWpMjSJuJVyhakWi8X0MXfv3gUgMjISgPLycgD27NkDwLRp0wC4f/8+0Bvy\nAGbPng1AWloaAJMmTXLaXqVciXiccpuamgDo7u4GICcnB4CAgAAAPn36BEB0dDQA8+bNA2Dy5MkA\nvHnzBoCKigoA5syZA4CPjw/BwcFO2auUKxG3Uq49fv36BUBNTY2uLj8/PwA92Tly5AhgKHTXrl0A\nfP36FYARI0YA8OPHDwBu374NQGdnp75eYmKiU3Yp5UrE5Xru36CnpwcwYq74+fPnT6A3LtbW1gJw\n8+ZNAA4cOADAs2fPAGPXIBCKFQgljxkzBjBi8PHjx0lMTOTDhw8AjB8/3qG9SrkScSvl2ipW0Ddb\nun79OgDr1q0DYP/+/U6do6OjA4D29nbA2Pdu3rwZMOJ731hsD6VciXjEbmEghNKEum1Vb7t3fvTo\nEWDsk6dMmQKAv78/QUFBeqXMx8fnt8/7QylXIm4Vcx3R09OjK9D2C2m2LiH2xaI+LY4TO5Vv374B\nRsY3EEq5EvEo5fat5QqFiveE8gT2lCx2C6dOnQJg9erVAKxcuRLojb0Aw4cPd2yPc+YrnGFIK9eV\n+q097CnYdu1x48YBEBUVBcCDBw8ASEpKAoyMzcyEkVKuRIa0cgdDsbbY67G9f/8egBkzZgBGnffx\n48dAbwcjNjZWr5qNGjXK4bk8PokQiHBg69wLFy4ARus8PT0dMNpBYrIoIiICX19fGhoaAJgwYYLD\nc6qwIJEhHRYGE6HYz58/A7Bv3z4AvVUuSoii8BMWFgYY6XN7eztWq9VUS10/5yDYrbDDkFCuvXjo\n6jpgFF5EgUU0LrOysgAIDw8H4NWrVwBs374d+O9FVFzo6uvrsVqtptJegVKuRIaEcm0Va7uBMbsl\n8/Ly0tUrFCvSWdFiT0hIAKCkpASAvLy8AdcU57Zd15Q9pn9T4TRDQrm2OJs89E2Tbb8FokEpitqi\nGH769GmnbGlpaQHA29vbtF1KuRIxpdzs7GzKy8vp6upi48aNREZGsnXrVrq7u7FarRw9etRUCc4e\ntgWa79+/A0a8FKoRe09b+lO6KBmKGCnS2HPnzvW7hm1Bx7bQI1rqzuDQuQ8fPqSmpoaCggJaW1tZ\nuXIlMTExpKSkkJCQQE5ODoWFhaSkpDh9ck/HoXOjoqL08lpgYCCdnZ2UlZWxd+9eABYtWkR+fv7/\n5Vxb5dXX1wPoebxouYhsaaAr
tlD7vXv3AKMtc/XqVZdsEa9fvnxp6vjfcOamtcuXL2tZWVladHS0\n/t7r16+1pKQkZ5b5ZzB9Qbtz5w6FhYX6AFufP46ZP6DDz82sY4bU1FRSU1MJDQ0lNDSUhoYG/Rtg\nBltbxPHJyclO22LKuSUlJeTm5nL27FkCAgLw9/fXLzqNjY2EhIQ4feJ/AYcxt6Ojg+zsbM6fP09Q\nUBAAc+fOpaioiOXLl3Pr1i3mz58/4BqO9q32BjXWrl0LGPEuOzsbgLi4uH7Xyc3N5dKlS4AxxmSm\n7joQYrcgahPO4NC5N27coLW1lczMTP29w4cPs3PnTgoKCpg4cSIrVqxw+sT/An+kE1FdXU1YWJg+\nCiSG18QuQFz9RfYjaqaiK1BQUAAYTcNr164BxvinuL0pPj5e/xaJY5zJqPrj7du3AGzatEk/r1lU\nhiaRP1JbePfuHWFhYfpAcmNjIwBtbW2AodyxY8cCRnY0depUADIyMgBDuaK/VVxcDMDTp08BWLp0\nKceOHQMMxTpzU15/+Pr6ArBs2TKnj1XKlciQ6P6KQeLW1lbAGDwW+bwwsaamBoDnz58DhvLj4+OB\n3pGjwMDAQbVN9NymT5/+2536ZlDKlciQUK47cOLECT32m0UpVyJKuRJRypWIcq5ElHMlopwrEeVc\niSjnSkQ5VyLKuRJRzpXIH6nnuuuDjG0njYqLi6mqqtJ7ievXr2fhwoX2F5Ddu3fXBxmXlpZq6enp\nmqZpWktLi7ZgwQJt27ZtWnFxsek1pCvX3oOM/+aDNc3Q36SRmFY3i/SY664PMvb29tbv8y0sLCQu\nLg5vb28uXrxIWloaW7Zs0QcE7fHH53M1NyvCiUmj/Px8KisrCQoKIjw8nDNnznDy5Mn/TCD1Rbpy\n3flBxraTRjExMfqNKosXL+bFixcDHi/duUPtQcZmEZNGeXl5+u4gIyNDn2MoKyvTb2W1h/Sw4K4P\nMu5v0mjVqlVkZmbi5+eHv78/hw4dGnAN1YmQiMrQJKKcKxHlXIko50pEOVciyrkSUc6VyP8AwKQj\nU8DNHBgAAAAASUVORK5CYII=\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fb7d415afd0\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Predicted class: 5, confidence: 99.869%\n",
            "Actual class: 5 \n",
            "\n",
            "\n"
          ]
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFcAAABYCAYAAACAnmu5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAACH5JREFUeJztnFlIFV8cxz9qC1qaLS7ZvphJFGUF\nStGiUdSDZVGKQRGVPkSoEEnRQiAtFkHLgyWYD0KbvfRQtGCLgloYbRZkVGabZlqaaaXN/0HOjHe8\nmxePf6+cD4jM3Dlnzvzu9/7md37nN+OhaZqGQgqe//cA+jLKuBJRxpWIMq5ElHEloowrkX6uNjx4\n8CBPnjzBw8OD3bt3M2PGjO4cV99Ac4HS0lItKSlJ0zRNe/36tbZu3TpXuunzuOQWiouLWbJkCQCT\nJk3ix48f/Pz5s1u/9L6AS8atra1l6NCh+vawYcP4+vVrtw2qr9AtNzRNzaCt4pJxAwMDqa2t1bdr\namoICAjotkH1FVwy7rx587hx4wYA5eXlBAYGMnjw4G4dmCtomqb/uUpKSgopKSlUVVVRVVWl729t\nbe1yXy6FYhEREUybNo2EhAQ8PDzYv3+/K930eTw0N3CYYogeHh4O95svx9ymra0NAC8vLwA+f/4M\nwLJlywB49uwZAJs2bQIgJydHbyfaOIuaoUnELZQrsKXKf//+dTpG7Ovfv7/FftGmoqICgPnz51sc\nJ27M9+/fB8DX11dvb/4VOEIpVyJupdzuoKGhAYBRo0YBMGTIEMDwxdevXwdg5syZgG1/7wxKuRJx\nOSv2f+BIRU1NTTQ1NQHQ2NgIwKNHjwD48OEDYMSrwrdOmDABQJ++h4aGdtt4lXIl4lbKFRGAiDfr\n6uoASEtL07eFDy0uLgZgypQpAJSUlACwatUqi/3Nzc0A+Pj4WJzD2bF4etrWp1KuRPpEtCDU5+3t\n7fDY4OBgAH79+gW050kAEhISANi4cSNgKFP4d03T8PT0dEqxAqVcibiVz7WFUKymaZ38spm4uDgA\nsrOzARgxYgQAt27dAgzlmpUpFCz2O/NrUcqVSJ9QbsfbhlmxZh+ZlJQEQFZWlsXnjx8/BuDv37+A\nkWsQVFdXExQUxJo1awAYPXo0ACdOnLA5LqVcifQJ5dqb95s/mzVrFmD4WhErBwYGAu0rKx23Y2Ji\n9LYvX77k3r17ALx7987huJRyJeLWynVmJcIWkydPBtAXWmtqagBD2aLPkSNHAjBgwADAmOE5s2ao\nlCsRt1ZuV3ytGZF7iIyMBOD9+/cAFBUVAUbeNzk5GTCybLGxsU6Pz62Na6bjUox5+tpxGgswcOBA\nAL1yyJY7Ee5ApCpFGZczKLcgkV6lXKEeWyoyq9Aeto6Njo4G0CcD58+ft9peLPuI/+Hh4YCRmnQG\npVyJ9KhyzaGTebsrynSEuY+tW7cCMHv2bACOHTtmd2zCZ4vSWHHj6wpKuRLpUeXaUqwtxKJhbm4u\nANu2bQM6+72O/Ym7er9+7Zd28uRJAL58+QIYqUZbYzNvC58bFhZm85y2UMqVSI8s84giNqECc+JZ\nqOvIkSMAjB8/3qL9ixcvABgzZgwAT58+tXqejnGumASIcqW7d+8C7VXw0Dm1aFai8LViLA8ePGDi\nxIkW5+p4vDWUciXilM/NzMykrKyM1tZWkpOTmT59Ojt37qStrY2AgACOHj2qz2SsIRLYtpZeHj58\nCMCnT5+AznfskJAQi89FoUdERIRFPx1VlJiYCMDatWsBQ7ECczLczPfv3wEjQdNRtc7i0LglJSVU\nVFRw8eJF6uvriYuLIyoqisTERJYvX87x48fJz8/XL0Zh4NC4c+fO1R/g8/Pzo7m5mdLSUg4cOADA\n4sWLycnJsWvciooKQkND9USz
WM4Wd/S3b99aHC8S2SLdJ9Tj7+8PwPr164H25LUZUbQsEjOXL192\ndIlWEWVRooTUjFMRT1ceWrtw4YK2Y8cOLTIyUt9XWVmpxcfH223X0tLSldP0GZyOc2/fvk1+fj45\nOTksXbq00zdoj6KiImJiYvRM1Lhx4wAjI/XmzRsAfv/+DRjz+OrqagAqKysBw2eLqCMzMxMw1Jqe\nns6VK1cAw9faimvNaCYlFhYWArBv3z4A7ty5Y/d4azgVLRQWFpKVlUV2dja+vr74+PjQ0tICtBtA\nrDcpLHEY5zY2NpKYmEhubi7Dhw8HYO/evcyZM4eVK1eSkZFBWFiYrhR7rFixAjCK4gRCkeJLErMs\nER0MGjQIMEqQxJDNj8QGBwfrSiorKwMMv+1IaebPr169CkBeXh4Aly5dcnh9Zhy6hWvXrlFfX09q\naqq+7/Dhw+zZs4eLFy8SEhKiryspLHFo3Pj4eOLj4zvtP3fuXJdPJhYFxQOCovBYzJY+fvwIGD7V\nz88PQHdBYkYn4l8RVXQsKRo7dixgKNZZxK9FxL9iyV0U7pmPE5GOPdQMTSI9mhXbtWsXAKdOnQKM\nwgrh74RSxWxKRBfCJ//588fiv1CwUFNDQwM3b960OKczd3VrCP8eFBRktT9nUMqVSI8q13znFtkt\nUXZfUFAAGPP6rhIbG8vUqVNdamtWtnjITzxS5QpKuRLplWX7osRI5A5E1kw8ov/t2zfAUJvI82Zk\nZOh9dNXXmo8XOWRRKiruB6psv5fQK5XbV1DKlYgyrkSUcSWijCsRZVyJKONKRBlXIsq4ElHGlUiP\nZMXc9UXG5kqjgoICysvL9fqJzZs3s2jRItsdyF67d9cXGRcXF2tbtmzRNE3T6urqtIULF2rp6ela\nQUGB031IV66tFxn3hhdr2sNapZFY23MW6T7XXV9k7OXlpRdZ5+fns2DBAry8vMjLy2PDhg2kpaXp\ni5i26PGneTQ3S8J1rDR6/vw5/v7+hIeHc/bsWU6fPq1X5FhDunLd+UXG5kqjqKgovdQqOjqaV69e\n2W0v3bi99UXGjmhsbCQzM5MzZ87o0cH27dv11ZDS0lKHL3iT7hbc9UXG1iqNVq9eTWpqKt7e3vj4\n+HDo0CG7faiVCImoGZpElHEloowrEWVciSjjSkQZVyLKuBL5D1WHWidFCGkdAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fb7ccf96490\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Predicted class: 7, confidence: 99.997%\n",
            "Actual class: 7 \n",
            "\n",
            "\n"
          ]
        }
      ],
      "source": [
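        "predictions = estimator.predict(input_fn=input_fn(\"predict\", training=False, batch_size=1))\n",
        "\n",
        "# Set the figure size before any figure is created: rcParams only affects\n",
        "# figures created after the change, so setting it inside the loop would leave\n",
        "# the first image at the default size.\n",
        "plt.rcParams['figure.figsize'] = (1, 1)\n",
        "\n",
        "for i, val in enumerate(predictions):\n",
        "    predicted_class = val['class_ids'][0]\n",
        "    # Confidence is the softmax probability assigned to the predicted class.\n",
        "    prediction_confidence = val['probabilities'][predicted_class] * 100\n",
        "\n",
        "    # Display the image\n",
        "    plt.imshow(x_test[i])\n",
        "    plt.show()\n",
        "\n",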
        "    print('Predicted class: %s, confidence: %s%%' % (predicted_class, round(prediction_confidence, 3)))\n",
        "    print('Actual class: %s \\n\\n' % y_test[i])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TKhCzP65hGyS"
      },
      "source": [
        "## Conclusion and next steps\n",
        "\n",
        "In this tutorial, you learned how to customize `adanet` to encode your\n",
        "understanding of a particular dataset, and explore novel search spaces with\n",
        "AdaNet.\n",
        "\n",
        "One use case that has worked for us at Google has been to take a production\n",
        "model's TensorFlow code, convert it into an `adanet.subnetwork.Builder`, and\n",
        "adaptively grow it into an ensemble. In many cases, this has yielded significant\n",
        "performance improvements.\n",
        "\n",
        "As an exercise, you can replace the Fashion-MNIST dataset with the MNIST handwritten\n",
        "digits dataset in this notebook by loading it with `tf.keras.datasets.mnist.load_data()`,\n",
        "and then see how `SimpleCNN` performs."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "customizing_adanet.ipynb",
      "provenance": [],
      "version": "0.3.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
